A Practical Guide to Neural Network Optimization


I've been training neural networks professionally for some time now, and this is a list of the "unwritten" stuff that makes your life easier when doing so: an opinionated list of miscellaneous tips and tricks for optimizing neural networks. It's a mix of programming skills, VM configuration, debugging and profiling. This post is focused on Pytorch, because that's the library that I've used most extensively, but many of these points apply to other ML libraries like JAX and Tensorflow.

This is a living document - I'll be adding to it over time.

Remote machines

Optimize for being able to run your code quickly on a new, raw virtual machine. Debugging directly on machines is much easier. Often tools like wandb won't capture stack traces or CUDA OOM errors in logs unless you've configured your code in a very precise way.

Use github's CLI tool (gh) not git to clone repos remotely. It's much faster because you don't have to mess around with remembering how to generate ssh keys and adding them to your github account, which gets annoying if you have to do it for multiple machines. People who love programming will tell you not to use gh and to "learn git properly", but it's not interesting, so just do this instead.

Learn how to use tmux. Tmux is very useful because it lets you create shell sessions which persist even if you log out of a remote machine. You need basically 4 commands: tmux new -s <session_name> to create a session, tmux attach -t <session_name> to reattach to it, tmux ls to list all your sessions, and tmux kill-session -t <session_name> to kill one. To leave a session without killing it, detach with ctrl+b d. Note that ctrl+b x only closes the current pane, which ends the session only if it was the last pane.

Do not skimp on CPU memory or disk space if you are training large neural networks (even 50M+ parameters). It is a small cost relative to the compute hardware you are using. Aim to have 16 CPUs per GPU with at least 2GB memory each. Get at least a 300GB SSD.

Use the Linux watch command to make a little dashboard of helpful stuff in your terminal. For example, when I work on a remote machine, I run watch -n 1 nvidia-smi, which reruns nvidia-smi every second (without -n, watch refreshes every two seconds). There are some problems that are infinitely easier to debug if you can watch the memory use of a GPU in coordination with your script running.

Similar to the above, install the htop tool (apt install htop on Debian or Ubuntu). It provides a much easier view of CPU usage when using a remote machine.


Make a public dotfiles github repo with a script that can set up a virtual machine for you. You can see an example of mine here, but below is an even simpler template:

# My bashrc
# If not running interactively, don't do anything.
[ -z "$PS1" ] && return
# Ignore repeated commands in .bash_history
export HISTCONTROL=ignoredups
# Default editor is vim.
export EDITOR=vim
export VISUAL=vim
# Never write bytecode for python.
export PYTHONDONTWRITEBYTECODE=1
# Make tmux able to use 256 colours.
alias tmux="TERM=screen-256color-bce tmux"
# makes your git branch look nice
parse_git_branch() {
     git branch 2> /dev/null | sed -e '/^[^*]/d' -e 's/* \(.*\)/ (\1)/'
}
export PS1="\u@\h \[\033[32m\]\w\[\033[33m\]\$(parse_git_branch)\[\033[00m\] $ "

Have a dead simple script that downloads and sets up anaconda. Here is a template:

#!/usr/bin/env bash
set -e
apt install -y tmux vim git wget htop tree
# Install anaconda.
if [ ! -d "$HOME/anaconda3" ]; then
  printf "\nInstalling anaconda\n"
  wget https://repo.anaconda.com/archive/Anaconda3-2022.05-Linux-x86_64.sh
  bash Anaconda3-2022.05-Linux-x86_64.sh
  rm Anaconda3-2022.05-Linux-x86_64.sh
else
  printf "\nAlready installed anaconda3? skipping python setup steps\n"
fi

Code structure

General Principle: Parallelize at the highest possible level. In a standard distributed setup, you will have N processes with a single GPU each, with K dataloader workers. Make sure that code within your dataloading processes is broadly single threaded - torch.distributed.elastic.launch does some of this for you by setting OMP_NUM_THREADS=1 and MKL_NUM_THREADS=1, but for example, if your dataloading uses scikit-learn to compute a nearest neighbor graph for a GNN input, ensure n_jobs=1. Otherwise you will have N*K processes all trying to use the same CPU resources, which will slow down your training and/or cause confusing CPU crashes.
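
As a rough sketch of the single-threaded worker idea, assuming a standard PyTorch DataLoader: the helper name limit_worker_threads is hypothetical, while torch.set_num_threads and the worker_init_fn argument are regular PyTorch features. The dataset here is a toy stand-in.

```python
import os

import torch
from torch.utils.data import DataLoader, TensorDataset


def limit_worker_threads(worker_id: int) -> None:
    """Run inside each dataloader worker to keep numerical libraries single threaded."""
    # Environment variables only affect libraries that read them after this point.
    os.environ["OMP_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    # Limit PyTorch's own intra-op thread pool in this process.
    torch.set_num_threads(1)


dataset = TensorDataset(torch.arange(8, dtype=torch.float32).unsqueeze(1))
# With num_workers > 0, PyTorch calls worker_init_fn once in every worker process.
# num_workers=0 is used here only so the sketch runs anywhere; use K in real training.
loader = DataLoader(dataset, batch_size=4, num_workers=0, worker_init_fn=limit_worker_threads)
```

Any library called from the dataset code has its own knob for this, so the hook above is a backstop rather than a replacement for checking those settings.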

Write your training script such that it is only aware of a single process and GPU. Then, use torchrun (or the older torch.distributed.launch) to launch multiple processes. This makes it much easier to debug your code, and you can use the same script for both single and multi-GPU training.

Keep your dataloading as simple as possible. Avoid the IterableDataset in Pytorch - coerce your dataset into a map style dataset, which Pytorch is designed around, so you can use the built in samplers. This avoids a lot of pain around seeding your dataloaders correctly when you have multiple workers.
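
For reference, a map-style dataset only needs __len__ and __getitem__. The sketch below uses a hypothetical toy dataset (SquaresDataset) to show that the built-in shuffling then works without any extra seeding logic.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset: item i is the pair (i, i squared)."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, index: int):
        x = torch.tensor(float(index))
        return x, x * x


dataset = SquaresDataset(10)
# Because the dataset is indexable, the built-in samplers work:
# shuffle=True uses a RandomSampler, and DistributedSampler can shard by index.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)
batches = list(loader)
```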

Do not make model features configurable from either a CLI or a config file until you know that varying them has an effect on model performance. This is important because it's appealing to try lots of experiments, and make them all configurable in a config file, even if they don't end up being research directions that you are pursuing. This is bad, because it will slow you down later on. Instead, make a branch which modifies the code directly, and run an experiment.

Try as hard as you can not to make a "library" for training models. This is extremely difficult to do well (Pytorch does not provide a training library, which should hint to you that it is hard).

Unless you are sure that you will never do anything non-standard, do not use Pytorch Lightning, or Accelerate, or any training framework. Generally, improvements to training (e.g. gradient accumulation, weight averaging etc) are simple enough to implement that it's worth the time learning how they work and implementing them directly. Non-standard things include:

  • Having batches that are different sizes
  • Having more than one "stage" of training (e.g pretraining and finetuning)
  • Having more than one model being optimized at once
  • Dynamically changing what data you are training on
  • Having evaluations which take a long time
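
To illustrate how small these training improvements can be when written by hand, here is a minimal gradient accumulation sketch. The model, data and accumulation_steps value are toy assumptions; the check at the end compares the accumulated gradient with the gradient of one large batch.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)
accumulation_steps = 4  # 4 micro-batches of 2 behave like one batch of 8

# Reference: gradient of the loss over the full batch (does not touch .grad).
full_batch_grad = torch.autograd.grad(loss_fn(model(inputs), targets), model.weight)[0]

optimizer.zero_grad(set_to_none=True)
micro_batches = zip(inputs.chunk(accumulation_steps), targets.chunk(accumulation_steps))
for micro_inputs, micro_targets in micro_batches:
    loss = loss_fn(model(micro_inputs), micro_targets)
    # Scale so the summed gradients equal the gradient of the mean loss over all items.
    (loss / accumulation_steps).backward()
accumulated_grad = model.weight.grad.clone()

# One optimizer step for the whole accumulated batch.
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```

The division by accumulation_steps only gives an exact match when the micro-batches are the same size, which is one of the details a framework would otherwise hide from you.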

Wrap your main script entrypoint with @record so you don't lose stack traces from data loader workers.

from torch.distributed.elastic.multiprocessing.errors import record
@record
def main():
    ...  # your training code goes here
if __name__ == "__main__":
    main()

Always check you can overfit on a single datapoint, and then a single batch of different datapoints. I cannot stress how important this is. It must be the first thing you do when you have a full training script and model ready. Make this easily configurable in your script, e.g a --debug flag, which overrides a bunch of other settings.
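
One possible shape for such a flag, as a sketch only: the --debug and --steps arguments, the tiny model and the random data are all illustrative assumptions, not a prescribed layout.

```python
import argparse

import torch
from torch import nn


def parse_args(argv=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", help="Overfit a single fixed batch.")
    parser.add_argument("--steps", type=int, default=1000)
    args = parser.parse_args(argv)
    if args.debug:
        # Debug mode overrides other settings: a short run on one reused batch.
        args.steps = 200
    return args


def train(args: argparse.Namespace) -> list:
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    fixed_batch = (torch.randn(8, 4), torch.randn(8, 1))
    losses = []
    for _ in range(args.steps):
        # In debug mode the same batch is reused; otherwise fresh (toy) data is drawn.
        inputs, targets = fixed_batch if args.debug else (torch.randn(8, 4), torch.randn(8, 1))
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        losses.append(loss.item())
    return losses


losses = train(parse_args(["--debug"]))
```

If the loss on the fixed batch does not fall steadily towards zero, something in the model, loss or optimizer wiring is broken, and it is far cheaper to find that out here than in a full run.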

On Ampere or newer NVIDIA GPUs (such as the A100), enable tf32 matrix multiplications by setting torch.backends.cuda.matmul.allow_tf32 = True. Since Pytorch 1.12, this has been disabled by default as it causes precision issues in scientific computing. This sped up the models I've been working on by 4x.

When resetting your optimizer, use optimizer.zero_grad(set_to_none=True) to avoid a memory copy (grads are just assigned rather than copied into). For the smaller models I checked this with, it improved the backward step wallclock time by ~6%. Additionally, you should call zero_grad right after optimizer.step(), so that you do not keep the gradients for the previous iteration in memory whilst you do your next forward pass.
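
A minimal sketch of that ordering, using a toy linear model as a stand-in for a real network:

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()
optimizer.step()
# Free the gradients straight away, before the next forward pass allocates activations.
optimizer.zero_grad(set_to_none=True)
grads_after_reset = [p.grad for p in model.parameters()]
```

With set_to_none=True the .grad attributes are None rather than zero tensors, so any code that reads gradients between steps needs to handle None.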

Learn how to use einops, a shorthand for many tensor-rearranging operations. It can be frustrating to read, but many high quality implementations of models rely on it, and its brevity is sometimes nice.

Unless you are specifically doing research into neural network optimization, just use Adam or AdamW. Use the fused versions.

Use the fused versions of optimizers in Pytorch. This is important. The fused versions of optimizers stack all your parameters together and make a single update using one CUDA kernel, instead of iterating in a for loop over the number of parameters. In my experience this can improve total wallclock time by ~20% and increase your device utilization substantially. If it's not available, use the foreach implementation.
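
A sketch of how the choice might be made in code. The helper make_adamw is hypothetical, and it assumes a Pytorch version whose AdamW accepts the fused and foreach keyword arguments.

```python
import torch
from torch import nn


def make_adamw(model: nn.Module, lr: float = 1e-3) -> torch.optim.AdamW:
    """Hypothetical helper: prefer the fused AdamW, fall back to foreach."""
    params = list(model.parameters())
    if all(p.is_cuda for p in params):
        # Fused implementation; assumes this Pytorch version supports fused=True.
        return torch.optim.AdamW(params, lr=lr, fused=True)
    # foreach batches the update into multi-tensor ops instead of a Python loop per parameter.
    return torch.optim.AdamW(params, lr=lr, foreach=True)


model = nn.Linear(4, 1)
optimizer = make_adamw(model)
weight_before = model.weight.detach().clone()
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
```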

Create tensors on the device you want them on: torch.randn(300, 300, device="cuda:0") is much faster than torch.randn(300, 300).cuda(). This is because the latter creates the tensor on the CPU, then copies it to the GPU. Similarly, if you are going to copy something to the CPU, copy the smallest thing possible - e.g if you are computing a metric, you want to compute the metric on the GPU and pass a scalar back to the CPU, rather than doing the reduction on the CPU.
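
A small sketch of both halves of this advice. It falls back to the CPU when no GPU is present, and the metric (mean absolute error on random tensors) is only a placeholder.

```python
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Allocated directly on the target device, with no intermediate CPU tensor.
predictions = torch.randn(300, 300, device=device)
labels = torch.randn(300, 300, device=device)

# Reduce on the device, then move a single scalar to the CPU.
mean_abs_error = (predictions - labels).abs().mean().item()
```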

This is controversial, but don't over index on replicability. Replicability (exactly reproducing a run with the same exact floating point operations) with GPUs is difficult to achieve, and using a fixed seed does not mean that your experiments are any "better". If a research idea works, it does not just work for a single seed. Equally, you may miss models which are stochastically good - RL models which use Proximal Policy Optimization only converge 20% of the time, but in that 20%, they are much better than standard policy gradients. Does that mean they don't work? No. But if you'd been using a fixed seed, and it happened to be one of the seeds which didn't work, you'd miss this.

Use a profiler every time you run your code, especially if you use wandb. The profiler can run for the first N steps of your training run only, so it's effectively free. It's an excellent way to look at a full trace.

import torch
import wandb
def _on_trace_ready(prof: torch.profiler.profile):
    run = wandb.run
    if run is not None:
        # Assumption: write the trace into the wandb run directory before uploading it.
        torch.profiler.tensorboard_trace_handler(run.dir)(prof)
        wandb.save(f"{run.dir}/*.pt.trace.json", policy="now")  # type: ignore
def get_profiler() -> torch.profiler.profile:
    schedule = {
        "skip_first": 10,  # Do nothing for the first few steps
        "wait": 5,  # Wait for a few steps, but viewable in trace
        "warmup": 2,  # trace, but discard results to burn in overhead from profiler
        "active": 1,  # Active tracing. Change this to capture more frames.
        "repeat": 1,  # only do this once during training
    }
    profiler = torch.profiler.profile(
        schedule=torch.profiler.schedule(**schedule),
        on_trace_ready=_on_trace_ready,
    )
    return profiler
# Use:
profiler = get_profiler()
profiler.start()  # Do this once before the training loop
profiler.step()  # Do this at the end of each training step
profiler.stop() # Fully stop the profiler after a few steps

Now you've spent the time creating the profiles for your models, read this excellent wandb article explaining how to interpret these traces. It is not as straightforward as a typical flamegraph, because of the asynchronous nature of GPUs. Learn how to identify CPU bottlenecks, slow GPU operations and memory issues from these graphs.


Most nans in deep learning happen because of 4 reasons:

  1. The number one reason your model outputs a nan value is that you are dividing by zero somewhere in your code. This can often happen if you mask a value, and then take an average with respect to the number of masked elements - if that mask is all zeros, you will divide by zero, causing a nan.
  2. You accidentally took the log of zero. This is not defined. If you have any log values (including common operations like log_softmax), check that you did not pass a zero valued tensor into them.
  3. Your input data contains weird values, or it is not what you expect. Write a test to check the literal values of your data.
  4. If your model's loss starts increasing rapidly and then nans, this is typically due to overflow and is caused by having a learning rate which is too high.
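
The first two causes can be guarded against with small helpers. The sketch below is one way to do it; the names masked_mean and safe_log and the eps value are illustrative assumptions.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of `values` where mask == 1; returns 0 when nothing is selected."""
    total = (values * mask).sum()
    # clamp avoids 0 / 0 when every element is masked out.
    count = mask.sum().clamp(min=1)
    return total / count


def safe_log(probs: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Log that never sees an exact zero (eps is an assumed tolerance)."""
    return torch.log(probs.clamp(min=eps))


values = torch.tensor([1.0, 2.0, 3.0])
empty_mask = torch.zeros(3)
naive = (values * empty_mask).sum() / empty_mask.sum()  # 0 / 0
safe = masked_mean(values, empty_mask)
```

Whether returning 0 for an empty mask is the right behaviour depends on the metric; the point is to make that choice explicitly rather than letting a nan propagate into the loss.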

Pytorch Gotchas

If your model's memory use grows on every forward and backward pass, you are storing a reference to a variable which is from a previous batch's computation graph. This is common when storing metrics - you need to call .detach() to remove it from the computation graph, so that it can be garbage collected. Until you do this, Pytorch must keep all the gradients and forward pass values for everything used to compute that metric - because if you then used that value further, Pytorch needs to be able to compute gradients with respect to it.
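
A toy illustration of the difference, with a small linear model standing in for a real network:

```python
import torch
from torch import nn

model = nn.Linear(4, 1)
leaky_history = []  # each entry keeps its batch's computation graph alive
safe_history = []   # each entry is a plain value with no graph attached

for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    leaky_history.append(loss)           # still attached to the graph via grad_fn
    safe_history.append(loss.detach())   # detached copy; the graph can be freed
```

Calling loss.item() also drops the graph, at the cost of forcing a synchronization with the GPU each time it is called.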

Pytorch's default weight initializations are a bit funky. It is safest to have these explicitly in your code, especially for conv nets. Read this blog post to understand this better.

If you use time.time() inside an optimization loop, the numbers will often be completely wrong, especially if you are using distributed training (because in DDP, the processes will block to wait for each other, and in general the execution of Pytorch code queues kernels to be launched in the future rather than actually running them). Instead, record time at a coarser granularity (like how long an epoch takes) and divide by the number of batches you've seen.
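
A sketch of timing at the loop level. The helper mean_step_time is hypothetical, and the torch.cuda.synchronize() calls are an addition beyond the advice above: they make sure queued GPU work is finished before the clock is read.

```python
import time

import torch


def mean_step_time(step_fn, num_steps: int) -> float:
    """Average seconds per step, measured over the whole loop rather than per step."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish queued kernels before starting the clock
    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the work launched inside the loop
    return (time.perf_counter() - start) / num_steps


seconds_per_step = mean_step_time(lambda: torch.randn(64, 64) @ torch.randn(64, 64), num_steps=5)
```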

When training in half-precision, it's important that LayerNorm-like layers (BatchNorm, GroupNorm etc) do not operate in half precision.
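
If you are casting a model to half precision by hand, one way to do this is to wrap the normalization layer so it always computes in float32. This is a sketch: the class name FP32LayerNorm is hypothetical, and the same pattern would need repeating for other norm layers.

```python
import torch
from torch import nn


class FP32LayerNorm(nn.LayerNorm):
    """Hypothetical wrapper: normalize in float32, return the caller's dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        original_dtype = x.dtype
        # Mean and variance are computed in float32 for numerical stability.
        out = super().forward(x.float())
        return out.to(original_dtype)


norm = FP32LayerNorm(16)  # parameters stay in float32
half_input = torch.randn(2, 16).half()
output = norm(half_input)
```

The layer's own weights must be left in float32 for this to work, so exclude it if you convert the rest of the model with .half().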

The random seeding in Pytorch's DistributedSampler is very confusing. To use it correctly, you must call sampler.set_epoch at the start of each epoch. Unfortunately, this breaks down if your training loop does not fully iterate over the dataset before running an evaluation (for example it is too big, or it is too small, so you repeat it multiple times). In these cases, you need a sampler which will give you the same set of indices for a given worker, but randomly shuffle them (the set of indices assigned to a specific worker then changes per epoch, when you call set_epoch). The code below does this.

from typing import Iterator
import torch
from torch.utils.data import DistributedSampler
class DistributedRandomSampler(DistributedSampler):
    """
    PyTorch's DistributedSampler splits the dataset indices
    into a subset for each process, but then iterates over
    these subset indices in the same order each time it is called.
    We would like to randomly shuffle the indices each time.
    NOTE: This class still requires that `set_epoch` is called
    each epoch to ensure that the random seed is updated for
    *splitting the indices across processes*.
    """
    def __iter__(self) -> Iterator:
        """Iterate over indices in a random order."""
        indices = list(super().__iter__())
        # Shuffle with the global RNG rather than the sampler's
        # epoch-based seed, so that each call gives a new order.
        if self.shuffle:
            shuff = torch.randperm(len(indices)).tolist()
            return (indices[i] for i in shuff)
        # Without shuffling, keep the order given by the base class.
        return iter(indices)

Distributed Training

Learn how Pytorch DDP works and understand how to pass information across processes in memory. Here is an example:

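  # file: test_ddp.py
  # run: torchrun --nproc_per_node=N test_ddp.py
  import torch
  import torch.distributed as dist
  # Assumption: the gloo backend, so the example also runs on CPU-only machines.
  dist.init_process_group(backend="gloo")
  x = torch.randn(2,2)
  # Setup output, one for each process
  output = [None for _ in range(dist.get_world_size())]
  # Gather every process's tensor into `output` on rank 0.
  dist.gather_object(
      x,
      output if dist.get_rank() == 0 else None,
      dst=0,
  )
  # In rank 0, this is a list of tensors. In the other ranks, it's a list of None.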

Within your script, use a single worker for logging by inspecting torch.distributed.get_rank(). Make a file for helpers like this. For example, you can use if torch.distributed.get_rank() == 0 to only log from the first process.
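
A sketch of what such a helper file might contain. The function names are hypothetical; the checks mean the same code also works in a plain single-process run where torch.distributed has not been initialized.

```python
import torch.distributed as dist


def get_rank_safe() -> int:
    """Rank of this process, or 0 when not running under torch.distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def is_main_process() -> bool:
    """True only on rank 0 (or in a non-distributed run)."""
    return get_rank_safe() == 0


def log_main(message: str) -> None:
    """Print only from rank 0 so logs are not duplicated once per process."""
    if is_main_process():
        print(message)
```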

Make liberal use of dist.barrier(). This forces all processes to wait at certain points, until all workers are at that point. This is done automatically for model.forward() and loss.backward(), so gradients can be averaged across processes - but often, in evaluation code or metrics computation, you want all the workers to have finished their batches before you start computing metrics across processes. This is particularly important if your batches can be different sizes (e.g. dynamic batches), because this means your workers may take different lengths of time to compute a fixed number of (different) batches.
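
As a sketch of the evaluation case: the hypothetical helper below waits for all workers, then sums per-process totals and counts to get a global mean. It assumes a backend that can reduce CPU tensors (such as gloo) and falls back to a local mean in a single process.

```python
import torch
import torch.distributed as dist


def distributed_mean(local_sum: float, local_count: int) -> float:
    """Mean of a metric over all processes (plain local mean in a single process)."""
    stats = torch.tensor([local_sum, float(local_count)], dtype=torch.float64)
    if dist.is_available() and dist.is_initialized():
        # Wait until every worker has finished its evaluation batches.
        dist.barrier()
        # Sum the per-process totals and counts (with NCCL, move `stats` to the GPU first).
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
    # clamp guards against an empty evaluation set.
    return (stats[0] / stats[1].clamp(min=1)).item()


local_mean = distributed_mean(6.0, 3)
```

Reducing sums and counts separately, rather than averaging per-process means, keeps the result correct when workers have seen different numbers of items.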

If you have a pytorch training job which hangs, py-spy is an extremely useful tool which lets you inspect a currently running python process. It can show a live, top-like view of the functions running in a given process. This is my go-to tool for debugging dataloading processes which hang: py-spy top --pid <my process id>

Further Reading