Save and load checkpoints
The ability to save and resume experiments is very useful when running long experiments. Avalanche offers a checkpointing functionality that can be used to save and restore your strategy, including plugins, metrics, and loggers.
This guide will show how to plug the checkpointing functionality into the usual Avalanche main script. This only requires minor changes in the main script: no changes to the strategy/plugins/... code are required! Also, make sure to check the checkpointing.py example in the repository for a ready-to-go template.
Continual learning vs classic deep learning
Resuming a continual learning experiment is not the same as resuming a classic deep learning training session. In classic training setups, the elements needed to resume an experiment are limited to i) the model weights, ii) the optimizer state, and iii) additional info such as the number of epochs/iterations so far. In contrast, continual learning experiments need far more information to be correctly resumed:
The state of plugins, such as:
the examples saved in the replay buffer
the importance of model weights (EWC, Synaptic Intelligence)
a copy of the model (LwF)
... and many others, which are specific to each technique!
The state of metrics, as some are computed from the performance measured on previous experiences:
AMCA (Average Mean Class Accuracy) metric
Forgetting metric
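To see why metric state matters, consider the simplified forgetting metric below. It is a hypothetical class, not Avalanche's implementation, and it assumes the common definition of forgetting on an experience as the accuracy measured right after training on it minus the latest accuracy measured on it. If the stored reference accuracies are lost on resume, forgetting silently drops to zero.

```python
from typing import Dict


class ToyForgetting:
    """Hypothetical, simplified forgetting metric (not Avalanche's class).

    Forgetting on experience k = first accuracy recorded on k
    minus the most recent accuracy recorded on k.
    """

    def __init__(self) -> None:
        self.initial_acc: Dict[int, float] = {}  # reference accuracy per experience
        self.last_acc: Dict[int, float] = {}     # latest accuracy per experience

    def update(self, exp_id: int, acc: float) -> None:
        # Only the first measurement becomes the reference value.
        self.initial_acc.setdefault(exp_id, acc)
        self.last_acc[exp_id] = acc

    def result(self) -> Dict[int, float]:
        return {k: self.initial_acc[k] - self.last_acc[k] for k in self.last_acc}

    # This is the state a checkpoint must contain, next to the model weights.
    def state_dict(self) -> dict:
        return {"initial_acc": dict(self.initial_acc), "last_acc": dict(self.last_acc)}

    def load_state_dict(self, state: dict) -> None:
        self.initial_acc = dict(state["initial_acc"])
        self.last_acc = dict(state["last_acc"])
```

Restoring `state_dict()` into a fresh instance keeps the reference accuracies, so a resumed run reports the same forgetting as an uninterrupted one; a freshly created metric would treat the post-resume accuracy as the reference.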
Resuming experiments in Avalanche
To handle all these elements, we opted to provide an easy-to-use plugin: the CheckpointPlugin. It takes care of saving and loading:
Strategy, including the model
Plugins
Metrics
Loggers: this includes re-opening the logs for TensorBoard, Weights & Biases, ...
State of all random number generators
In continual learning experiments, this affects the choice of replay examples and other critical elements. This is usually not needed in classic deep learning, but here it may be useful!
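The snippet below is a minimal sketch, using only Python's `random` module, of why the RNG state (and not just the seed) has to be checkpointed. `sample_replay` is a hypothetical helper standing in for replay-buffer sampling: re-seeding after an interruption repeats the draws of the first experience, while restoring the saved state continues exactly where the uninterrupted run would.

```python
import random
from typing import List


def sample_replay(pool: List[int], k: int) -> List[int]:
    """Hypothetical replay sampling that relies on the global Python RNG."""
    return random.sample(pool, k)


pool = list(range(100))
random.seed(0)
first = sample_replay(pool, 5)  # draws made during the first experience

# "Checkpoint": store the RNG state at the end of the first experience.
saved_state = random.getstate()
reference = sample_replay(pool, 5)  # what an uninterrupted run picks next

# Resume by re-seeding only: this repeats the first experience's draws.
random.seed(0)
reseeded = sample_replay(pool, 5)

# Resume by restoring the saved state: this matches the uninterrupted run.
random.setstate(saved_state)
restored = sample_replay(pool, 5)
```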
Here, in a couple of cells, we'll show you how to use it. Remember that you can follow this guide by running it as a notebook (see below for a direct link to load it on Colab).
Let's install Avalanche:
And let us import the needed elements:
Let's proceed by defining a very vanilla Avalanche main script. Simply put, this usually comes down to defining:
Load any configuration, set seeds, etcetera
The benchmark
The model, optimizer, and loss function
Evaluation components
The list of metrics to track
The loggers
The evaluation plugin (that glues the metrics and loggers together)
The training plugins
The strategy
The train-eval loop
They do not have to be in this particular order, but this is the order followed in this guide.
To enable checkpointing, the following changes are needed:
In the very first part of the code, fix the seeds for reproducibility
The RNGManager class is used, which may be useful even in experiments in which checkpointing is not needed ;)
Instantiate the checkpointing plugin
Check if a checkpoint exists and load it
Only if not resuming from a checkpoint: create the Evaluation components, the plugins, and the strategy
Change the train/eval loop so that it starts from the next experience to run, as recorded in the checkpoint
Note that these changes are all properly annotated in the checkpointing.py example, which is the recommended template to follow when enabling checkpointing in a training script.
Step by step
Let's start with the first change: defining a fixed seed. This is needed to correctly re-create the benchmark object and should be the same seed used to create the checkpoint.
The RNGManager takes care of setting the seed for the following generators: Python random, NumPy, and PyTorch (both CPU and device-specific generators). In this way, you can be sure that any randomness-dependent elements in the benchmark creation procedure are identical across save/resume operations.
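As an illustration of what seeding all generators involves, here is a hypothetical stand-in for RNGManager; the function names are assumptions and not Avalanche's API. NumPy and PyTorch are seeded only if they are installed, and `make_benchmark_order` is a toy replacement for a benchmark whose creation draws a random class order.

```python
import random


def set_random_seeds(seed: int) -> None:
    """Hypothetical stand-in for RNGManager: seed every available generator."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass  # NumPy not installed: nothing to seed
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and all device generators
    except ImportError:
        pass  # PyTorch not installed: nothing to seed


def make_benchmark_order(n_classes: int, seed: int) -> list:
    """Toy 'benchmark creation' that depends on randomness (class order)."""
    set_random_seeds(seed)
    order = list(range(n_classes))
    random.shuffle(order)
    return order
```

Calling `make_benchmark_order(10, 1234)` in the original run and again after a restart yields the same class order, which is what lets the resumed script rebuild an identical benchmark.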
Let's then proceed with the usual Avalanche code. Note: nothing needs to change here to enable checkpointing. Here we create a SplitMNIST benchmark and instantiate a multi-task MLP model. Notice that checkpointing works fine with multi-task models wrapped using as_multitask.
It's now time to instantiate the checkpointing plugin and load the checkpoint.
Please notice the arguments passed to the CheckpointPlugin constructor:
The first parameter is a storage object. We decided to allow the checkpointing plugin to save and load checkpoints from arbitrary storages. The simplest storage, FileSystemCheckpointStorage, uses a given directory to store the file for the current experiment (do not point multiple experiments/runs to the same directory!). However, we plan to expand this in the future to support network/cloud storage. Contributions on this are welcome :-)! Remember that the CheckpointStorage interface is quite simple to implement in a way that best fits your needs.
The second parameter is the device used for training. This functionality may be particularly useful in some cases: the plugin will take care of loading the checkpoint on the correct device, even if the checkpoint was created on a CUDA device with a different id. This means that it can also be used to resume a CUDA checkpoint on CPU. The only caveat is that it cannot be used to load a CPU checkpoint to CUDA. In brief: CUDA -> CPU (OK), CUDA:0 -> CUDA:1 (OK), CPU -> CUDA (NO!). This will also take care of updating the device field of the strategy (and plugins) to point to the current device.
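To give an idea of how small a storage backend can be, here is a sketch of a directory-based storage. The interface and its `save`/`load` methods are hypothetical and do not mirror the signatures of Avalanche's CheckpointStorage; the sketch only shows the idea of one directory per run, plus an atomic write so that a crash while saving cannot corrupt the previous checkpoint.

```python
import os
import pickle
import tempfile
from abc import ABC, abstractmethod
from typing import Any, Optional


class ToyCheckpointStorage(ABC):
    """Hypothetical storage interface (method names are illustrative only)."""

    @abstractmethod
    def save(self, name: str, state: Any) -> None: ...

    @abstractmethod
    def load(self, name: str) -> Optional[Any]: ...


class ToyFileSystemStorage(ToyCheckpointStorage):
    """Keeps each checkpoint as a pickle file inside one directory per run."""

    def __init__(self, directory: str) -> None:
        self.directory = directory
        os.makedirs(directory, exist_ok=True)

    def _path(self, name: str) -> str:
        return os.path.join(self.directory, name + ".pkl")

    def save(self, name: str, state: Any) -> None:
        # Write to a temporary file, then replace the old checkpoint in one step.
        tmp_path = self._path(name) + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp_path, self._path(name))

    def load(self, name: str) -> Optional[Any]:
        if not os.path.exists(self._path(name)):
            return None  # no checkpoint yet: the caller starts from scratch
        with open(self._path(name), "rb") as f:
            return pickle.load(f)


def demo() -> tuple:
    """Save and reload one checkpoint in a throwaway directory."""
    with tempfile.TemporaryDirectory() as run_dir:
        storage = ToyFileSystemStorage(run_dir)
        before = storage.load("latest")          # None on a fresh run
        storage.save("latest", {"next_exp": 3})  # e.g. after the third eval phase
        after = storage.load("latest")
    return before, after
```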
The next change is in fact quite minimal. It only requires wrapping the creation of the plugins, metrics, loggers, and strategy in an "if" that checks whether a checkpoint was actually loaded. If a checkpoint is loaded, the resumed strategy already contains the properly restored plugins, metrics, and loggers: creating them again would be an error.
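The resume-or-create pattern can be sketched without Avalanche as follows. `ToyStrategy`, `load_if_exists`, and `build_or_resume` are hypothetical names, and the sketch assumes that loading yields either a restored strategy plus the index of the next experience, or `None` plus 0 when no checkpoint exists.

```python
from typing import Optional, Tuple


class ToyStrategy:
    """Hypothetical stand-in for a strategy bundling plugins, metrics, loggers."""

    def __init__(self, plugins: list) -> None:
        self.plugins = plugins


def load_if_exists(saved: Optional[dict]) -> Tuple[Optional[ToyStrategy], int]:
    """Return (restored strategy, next experience) or (None, 0) without a checkpoint."""
    if saved is None:
        return None, 0
    return saved["strategy"], saved["next_exp"]


def build_or_resume(saved: Optional[dict]) -> Tuple[ToyStrategy, int]:
    strategy, initial_exp = load_if_exists(saved)
    if strategy is None:
        # Fresh start only: create evaluation components, plugins, and strategy.
        # On resume, these already live inside the restored strategy, and
        # creating new ones would throw away their saved state.
        strategy = ToyStrategy(plugins=["replay"])
    return strategy, initial_exp
```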
Final change: adapt the for loop so that the training stream is iterated starting from initial_exp. This variable was created when loading the checkpoint and holds the index of the next training experience to run. If no checkpoint was found, its value is 0.
A new checkpoint is stored at the end of each eval phase! If the program is interrupted before that, all progress made since the previous eval phase is lost.
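The loop change and the save-after-eval rule can be simulated with plain Python. In this sketch, `store` is a dictionary standing in for checkpoint storage and `crash_after` is a hypothetical knob that simulates an interruption; both are assumptions for illustration only.

```python
from typing import Dict, List


def run(stream: List[str], store: Dict[str, dict], crash_after: int = -1) -> List[str]:
    """Train on stream[initial_exp:], storing a checkpoint after each eval phase."""
    ckpt = store.get("latest")
    initial_exp = ckpt["next_exp"] if ckpt else 0
    trained = list(ckpt["trained"]) if ckpt else []  # experiences already done

    for exp_id, experience in enumerate(stream[initial_exp:], start=initial_exp):
        trained.append(experience)  # stands in for strategy.train(experience)
        # strategy.eval(test_stream) would run here.
        # End of the eval phase: store a new checkpoint.
        store["latest"] = {"next_exp": exp_id + 1, "trained": list(trained)}
        if exp_id == crash_after:
            raise RuntimeError("simulated interruption")
    return trained
```

Running it once with `crash_after=1` and then again without it trains only `e2` and `e3` in the second run, because the stored checkpoint records that the first two experiences are complete.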
Here, exit_early is a simple placeholder that you can use to experiment a bit. You may obtain a similar effect by stopping this notebook manually, restarting the kernel, and re-running all cells. You will notice that the last checkpoint will be loaded and training will resume as expected.
Usually, exit_early should be implemented as a mechanism able to gracefully stop the process. When using SLURM or other schedulers (or even when terminating processes using Ctrl-C), you can catch termination signals and manage them properly so that the process exits after the next eval phase. However, don't worry if the process is killed abruptly: the last checkpoint will be loaded correctly once the experiment is restarted by the scheduler.
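One possible way to implement exit_early is a signal handler that only sets a flag, which the loop checks after each eval phase (and therefore after each checkpoint). The class below is a hypothetical sketch built on Python's standard `signal` module; it assumes the scheduler sends SIGTERM before killing the job, as is common, and `install()` must be called from the main thread.

```python
import signal


class GracefulExit:
    """Hypothetical exit_early: records a termination request as a flag."""

    def __init__(self) -> None:
        self.requested = False

    def install(self) -> "GracefulExit":
        # Handle SIGTERM (scheduler termination) and SIGINT (Ctrl-C).
        # Python only allows this from the main thread.
        signal.signal(signal.SIGTERM, self._handle)
        signal.signal(signal.SIGINT, self._handle)
        return self

    def _handle(self, signum, frame) -> None:
        self.requested = True  # the loop decides when it is safe to stop


def train_loop(n_experiences: int, stop: GracefulExit) -> int:
    """Return how many experiences were completed before stopping."""
    done = 0
    for exp_id in range(n_experiences):
        # Train, eval, and checkpoint experience exp_id here.
        done += 1
        if stop.requested:
            break  # the checkpoint for exp_id is stored: safe to exit
    return done
```

In a main script, `stop = GracefulExit().install()` would be called once at startup, so that a termination request lets the current experience finish its eval phase and checkpoint before the process exits.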
That's it for the checkpointing functionality! This is a relatively new mechanism, and feedback on it is warmly welcome in the discussions section of the repository!
🤝 Run it on Google Colab