https://github.com/super-dainiu/play_fx
Playing torch.FX for multiple usages
- Host: GitHub
- URL: https://github.com/super-dainiu/play_fx
- Owner: super-dainiu
- Created: 2023-02-07T08:26:20.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2023-02-08T08:15:50.000Z (over 2 years ago)
- Last Synced: 2025-01-29T16:43:14.829Z (4 months ago)
- Language: Python
- Size: 14.6 KB
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# play_fx
Playing with torch.FX for multiple use cases.

## FX Graph
See the official documentation at https://pytorch.org/docs/stable/fx.html; it explains the FX graph representation clearly enough.

## Original FX passes
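Before running any pass, it can help to look at what symbolic tracing actually produces. The sketch below restates the `TwoLayerNet` model used in this section compactly, traces it, and prints the graph, each node's opcode, name, target and arguments, and the Python code that `GraphModule` regenerates from the graph.

```python
import torch
import torch.fx


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        return self.linear2(h_relu)


gm = torch.fx.symbolic_trace(TwoLayerNet(1000, 100, 10))

# The traced Graph holds an ordered list of Nodes; each Node has an opcode such as
# placeholder, call_module, call_method, call_function, get_attr or output
print(gm.graph)

for node in gm.graph.nodes:
    print(node.op, node.name, node.target, node.args)

# GraphModule generates Python source for forward() from the graph
print(gm.code)
```

Because `clamp` is a `Tensor` method, it shows up as a `call_method` node, while the two `nn.Linear` submodules show up as `call_module` nodes.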
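```python
import torch
import torch.fx
from torch.fx.passes.shape_prop import ShapeProp


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)  # ReLU expressed via clamp
        y_pred = self.linear2(h_relu)
        return y_pred


N, D_in, H, D_out = 64, 1000, 100, 10
x = torch.randn(N, D_in)  # random training data; not needed by ShapeProp below
y = torch.randn(N, D_out)

model = TwoLayerNet(D_in, H, D_out)
gm = torch.fx.symbolic_trace(model)  # trace the module into a GraphModule

# ShapeProp runs the graph on a sample input and stores a TensorMetadata
# (shape, dtype, ...) in node.meta['tensor_meta'] for each tensor-producing node
sample_input = torch.randn(50, D_in)
ShapeProp(gm).propagate(sample_input)

for node in gm.graph.nodes:
    print(node.name, node.meta['tensor_meta'].dtype,
          node.meta['tensor_meta'].shape)
```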
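`ShapeProp` is built as a subclass of `torch.fx.Interpreter`, which executes a `GraphModule` one node at a time. The sketch below shows the same pattern in miniature: a hypothetical `RecordShapes` pass overrides `run_node` and stores each tensor result's shape under a made-up `node.meta` key, `'my_shape'`. It illustrates the idea only and is not a copy of `ShapeProp`, which records richer `TensorMetadata`.

```python
import torch
import torch.fx


class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        return self.linear2(h_relu)


class RecordShapes(torch.fx.Interpreter):
    """Hypothetical minimal pass: execute each node and record its output shape."""

    def run_node(self, n):
        result = super().run_node(n)  # actually execute the node
        if isinstance(result, torch.Tensor):
            n.meta["my_shape"] = tuple(result.shape)  # 'my_shape' is a made-up key
        return result


gm = torch.fx.symbolic_trace(TwoLayerNet(1000, 100, 10))
RecordShapes(gm).run(torch.randn(50, 1000))

for node in gm.graph.nodes:
    print(node.name, node.op, node.meta["my_shape"])
```

Overriding `run_node` is enough here because `Interpreter.run` calls it for every node, including the placeholder and output nodes, so each node ends up with a recorded shape.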