Get more out of your backward pass
BackPACK is a library built on top of PyTorch that makes it easy to extract more information from a backward pass. As a baseline, here is how the ordinary gradient is computed with plain PyTorch:
```python
"""
Compute the gradient with PyTorch
"""
from torch.nn import CrossEntropyLoss, Linear
from backpack.utils.examples import load_one_batch_mnist

# One batch of flattened MNIST images (784 features each) and their labels
X, y = load_one_batch_mnist(flat=True)
model = Linear(784, 10)
lossfunc = CrossEntropyLoss()
loss = lossfunc(model(X), y)
loss.backward()

for param in model.parameters():
    print(param.grad)
```
Install with:

```bash
pip install backpack-for-pytorch
```
If you use BackPACK in your research, please cite:

```bibtex
@inproceedings{dangel2020backpack,
    title     = {BackPACK: Packing more into Backprop},
    author    = {Felix Dangel and Frederik Kunstner and Philipp Hennig},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=BJlrF24twB}
}
```