https://github.com/t-k-233/diffusion-benchmark
- Host: GitHub
- URL: https://github.com/t-k-233/diffusion-benchmark
- Owner: T-K-233
- Created: 2023-12-23T20:21:35.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2024-02-27T02:53:12.000Z (over 1 year ago)
- Last Synced: 2025-04-06T09:26:52.773Z (6 months ago)
- Language: Python
- Size: 226 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Torch Exporter
## Installation
Set up the TensorRT environment by following [this guide](https://notes.tk233.xyz/ml-rl/setting-up-tensorrt-environment-on-ubuntu-2x.04), then install this package:
```bash
pip install -e .
```

## Example Usage
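Before running the example, it can help to confirm that the required packages are importable. The helper below is a hypothetical sketch, not part of `torchconverter`; it assumes the TensorRT Python bindings are installed under the import name `tensorrt` and that `pip install -e .` provides the `torchconverter` module used in the example. It only looks modules up with `importlib.util.find_spec`, so nothing is actually imported.

```python
import importlib.util


def check_environment(packages=("torch", "tensorrt", "torchconverter")):
    """Return {package_name: is_importable} for each top-level package name.

    Hypothetical helper: find_spec locates a module without importing it and
    returns None when a top-level module cannot be found.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    # Print one line per package so missing dependencies are easy to spot.
    for name, found in check_environment().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```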
```python
import os

import torch
from torchconverter import TorchTRT
from model import TransformerForDiffusion

# torch model
model = TransformerForDiffusion(device="cpu")

# sample input
sample = torch.rand((1, 16, 12), dtype=torch.float32)
timestep = torch.rand((1, ), dtype=torch.float32)
cond = torch.rand((1, 8, 42), dtype=torch.float32)

# convert to TensorRT
trt_model = TorchTRT(
    model=model,
    input_names=["sample", "timestep", "cond"],
    output_names=["action"]
)
trt_model.convert(
    example_inputs=(sample, timestep, cond)
)

# save the converted TensorRT engine (create the output directory first)
os.makedirs("models", exist_ok=True)
trt_model.save("models/model.plan")

# load the saved engine and run inference on the same inputs
trt_model = TorchTRT(model)
trt_model.load("models/model.plan")
result = trt_model.forward((sample, timestep, cond))
print(result)
```
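To compare the latency of the PyTorch model and the TensorRT engine, a small timing helper is enough. The `benchmark` function below is a hypothetical sketch, not part of `torchconverter`: it times any zero-argument callable with `time.perf_counter` after a few warm-up calls. Pass a `sync` callable such as `torch.cuda.synchronize` when timing GPU work, because CUDA kernels launch asynchronously and would otherwise be under-counted.

```python
import statistics
import time
from typing import Any, Callable, Optional


def benchmark(fn: Callable[[], Any], warmup: int = 10, iters: int = 100,
              sync: Optional[Callable[[], None]] = None) -> dict:
    """Time repeated calls to fn() and return latency statistics in milliseconds.

    Hypothetical helper. If sync is given, it is called after each timed call
    so that asynchronous (e.g. GPU) work is included in the measurement.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    # Warm-up calls are not timed (lazy initialization, caches, clock ramp-up).
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    times_ms = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times_ms.append((time.perf_counter() - start) * 1000.0)

    return {
        "mean_ms": statistics.mean(times_ms),
        "median_ms": statistics.median(times_ms),
        "min_ms": min(times_ms),
        "max_ms": max(times_ms),
    }
```

For example, `benchmark(lambda: trt_model.forward((sample, timestep, cond)))` times the engine from the example above. Assuming the PyTorch model accepts the three tensors positionally, `benchmark(lambda: model(sample, timestep, cond))` gives a baseline; call it inside a `with torch.no_grad():` block to skip autograd bookkeeping.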