https://github.com/kentonishi/torch-mutable-modules
Use in-place and assignment operations on PyTorch module parameters with support for autograd.
- Host: GitHub
- URL: https://github.com/kentonishi/torch-mutable-modules
- Owner: KentoNishi
- License: mit
- Created: 2022-01-30T21:44:03.000Z (almost 3 years ago)
- Default Branch: master
- Last Pushed: 2022-06-06T18:10:20.000Z (over 2 years ago)
- Last Synced: 2024-09-14T13:08:57.239Z (2 months ago)
- Topics: pytorch, torch
- Language: Python
- Homepage: https://pypi.org/project/torch-mutable-modules/
- Size: 28.3 KB
- Stars: 7
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: readme.md
- License: LICENSE
# Torch Mutable Modules
Use in-place and assignment operations on PyTorch module parameters with support for autograd.
[![Publish to PyPI](https://github.com/KentoNishi/torch-mutable-modules/actions/workflows/publish.yaml/badge.svg)](https://github.com/KentoNishi/torch-mutable-modules/actions/workflows/publish.yaml)
[![Run tests](https://github.com/KentoNishi/torch-mutable-modules/actions/workflows/test.yaml/badge.svg)](https://github.com/KentoNishi/torch-mutable-modules/actions/workflows/test.yaml)
[![PyPI version](https://img.shields.io/pypi/v/torch-mutable-modules.svg?style=flat)](https://pypi.org/project/torch-mutable-modules/)
[![Number of downloads from PyPI per month](https://img.shields.io/pypi/dm/torch-mutable-modules.svg?style=flat)](https://pypi.org/project/torch-mutable-modules/)
![Python version support](https://img.shields.io/pypi/pyversions/torch-mutable-modules)
[![Code Style: Black](https://img.shields.io/badge/code%20style-black-black.svg)](https://github.com/ambv/black)

## Why does this exist?
PyTorch does not allow in-place operations on module parameters (a restriction that is usually desirable):
```python
linear_layer = torch.nn.Linear(1, 1)
linear_layer.weight.data += 69
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Valid, but will NOT store grad_fn=
linear_layer.weight += 420
# ^^^^^^^^^^^^^^^^^^^^^^^^
# RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```

In some cases, however, it is useful to be able to modify module parameters in-place. For example, if we have a neural network (`net_1`) that predicts the parameter values for another neural network (`net_2`), we need to be able to modify the weights of `net_2` in-place and backpropagate the gradients to `net_1`.
```python
import torch
from torch_mutable_modules import to_mutable_module

# input_0, input_1, label, and criterion are assumed to be defined elsewhere

# create a parameter predictor network (net_1)
net_1 = torch.nn.Linear(1, 2)

# create an optimizer for net_1
optimizer_1 = torch.optim.SGD(net_1.parameters(), lr=0.01)

# predict the weights and biases of net_2 using net_1
p_weight_and_bias = net_1(input_0).unsqueeze(2)
p_weight, p_bias = p_weight_and_bias[:, 0], p_weight_and_bias[:, 1]

# create a mutable network (net_2)
net_2 = to_mutable_module(torch.nn.Linear(1, 1))

# hot-swap the weights and biases of net_2 with the predicted values
net_2.weight = p_weight
net_2.bias = p_bias

# compute the output and backpropagate the gradients to net_1
output = net_2(input_1)
loss = criterion(output, label)
loss.backward()
optimizer_1.step()
```

This library provides a way to easily convert PyTorch modules into mutable modules with the `to_mutable_module` function.
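
To make the gradient flow in the example above concrete, here is a minimal sketch in plain PyTorch that applies the predicted weight and bias with `torch.nn.functional.linear` instead of a converted module. It is not how `torch-mutable-modules` is implemented; it only illustrates that a loss computed from predicted parameters can backpropagate into `net_1`. The toy tensors (`input_0`, `input_1`, `label`) and the choice of mean squared error are assumptions made for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# net_1 predicts two numbers: a weight and a bias for a 1-in, 1-out linear map
net_1 = torch.nn.Linear(1, 2)
optimizer_1 = torch.optim.SGD(net_1.parameters(), lr=0.01)

# hypothetical toy data, chosen only for illustration
input_0 = torch.ones(1, 1)        # input to the predictor network
input_1 = torch.tensor([[2.0]])   # input to the predicted linear map
label = torch.tensor([[5.0]])     # target for the predicted linear map

# predict the parameters and reshape them to the shapes F.linear expects
predicted = net_1(input_0)                 # shape (1, 2)
p_weight = predicted[:, 0].reshape(1, 1)   # (out_features, in_features)
p_bias = predicted[:, 1].reshape(1)        # (out_features,)

# apply the predicted parameters functionally; both are non-leaf tensors,
# so autograd tracks them back to net_1
output = F.linear(input_1, p_weight, p_bias)
loss = F.mse_loss(output, label)

optimizer_1.zero_grad()
loss.backward()

# every parameter of net_1 should now hold a gradient
grads_reached_net_1 = all(p.grad is not None for p in net_1.parameters())
optimizer_1.step()
```

The functional form works for a single layer, but it requires rewriting the forward pass by hand. Assigning predicted tensors to the attributes of an existing module, as in the example above, avoids that rewrite.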
## Installation
You can install `torch-mutable-modules` from [PyPI](https://pypi.org/project/torch-mutable-modules/).

```bash
pip install torch-mutable-modules
```

To upgrade an existing installation of `torch-mutable-modules`, use the following command:

```bash
pip install --upgrade --no-cache-dir torch-mutable-modules
```

## Importing
You can use wildcard imports or import specific functions directly:
```python
# import all functions
from torch_mutable_modules import *

# ... or import the function manually
from torch_mutable_modules import to_mutable_module
```

## Usage
To convert an existing PyTorch module into a mutable module, use the `to_mutable_module` function:
```python
converted_module = to_mutable_module(
    torch.nn.Linear(1, 1)
)  # type of converted_module is still torch.nn.Linear

converted_module.weight *= 0
converted_module.weight += 69
converted_module.weight  # tensor([[69.]], grad_fn=)
```

You can also declare your own PyTorch module classes as mutable, and all child modules will be recursively converted into mutable modules:
```python
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

my_module = to_mutable_module(MyModule())
my_module.linear.weight *= 0
my_module.linear.weight += 69
my_module.linear.weight  # tensor([[69.]], grad_fn=)
```

### Usage with CUDA
To create a module on the GPU, simply pass a PyTorch module that is already on the GPU to the `to_mutable_module` function:
```python
converted_module = to_mutable_module(
    torch.nn.Linear(1, 1).cuda()
)  # converted_module is now a mutable module on the GPU
```

Moving a module to the GPU with `.to()` or `.cuda()` after instantiation is ***NOT*** supported. Instead, hot-swap the module parameter tensors with their CUDA counterparts:
```python
# both of these are valid
converted_module.weight = converted_module.weight.cuda()
converted_module.bias = converted_module.bias.to("cuda")
```

### Detailed examples
See [example.py](./example.py) for more detailed usage examples of the `to_mutable_module` function.
## Contributing
Please feel free to submit issues or pull requests!