https://github.com/t-k-233/diffusion-benchmark

# Torch Exporter

## Installation

Set up the TensorRT environment by following [this guide](https://notes.tk233.xyz/ml-rl/setting-up-tensorrt-environment-on-ubuntu-2x.04), then install this package in editable mode:

```bash
pip install -e .
```
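Before running the example, it can help to confirm that the required packages are importable. The snippet below is a minimal sketch that uses only the standard library and checks for modules without importing them. The package names are assumptions: `tensorrt` is the usual name of the TensorRT Python bindings, and `torchconverter` is the module imported in the usage example. `check_environment` is a hypothetical helper, not part of this package.

```python
import importlib.util


def check_environment(packages=("torch", "tensorrt", "torchconverter")):
    """Hypothetical helper: report which top-level packages can be found.

    Uses importlib.util.find_spec, which locates a module without executing it.
    """
    return {name: importlib.util.find_spec(name) is not None for name in packages}


if __name__ == "__main__":
    for name, found in check_environment().items():
        print(f"{name:15s} {'found' if found else 'MISSING'}")
```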

## Example Usage

```python
import os

import torch
from torchconverter import TorchTRT

from model import TransformerForDiffusion

# PyTorch model to convert
model = TransformerForDiffusion(device="cpu")

# example inputs passed to the converter
sample = torch.rand((1, 16, 12), dtype=torch.float32)
timestep = torch.rand((1, ), dtype=torch.float32)
cond = torch.rand((1, 8, 42), dtype=torch.float32)

# convert to a TensorRT engine and serialize it to disk
trt_model = TorchTRT(
    model=model,
    input_names=["sample", "timestep", "cond"],
    output_names=["action"],
)
trt_model.convert(
    example_inputs=(sample, timestep, cond),
)
os.makedirs("models", exist_ok=True)  # make sure the output directory exists
trt_model.save("models/model.plan")

# load the serialized engine into a fresh wrapper
trt_model = TorchTRT(model)
trt_model.load("models/model.plan")

# run inference; the inputs are passed as a single tuple
result = trt_model.forward((sample, timestep, cond))
print(result)
```
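After conversion, it is worth checking that the TensorRT output agrees with the original PyTorch model and measuring inference latency. The helpers below are a minimal sketch assuming NumPy is installed and that each output is a single tensor or array. `outputs_match`, `mean_latency_ms`, and the default tolerances are hypothetical choices, not part of `torchconverter`.

```python
# Hypothetical helpers (not part of torchconverter); assumes NumPy is available.
import time

import numpy as np


def to_numpy(x):
    """Convert a tensor-like object (torch tensor, NumPy array, list) to a NumPy array."""
    if hasattr(x, "detach"):  # torch.Tensor: drop the autograd graph and copy to host
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def outputs_match(reference, candidate, rtol=1e-3, atol=1e-3):
    """Return (all_close, max_abs_diff) for two outputs of the same shape."""
    ref = to_numpy(reference).astype(np.float64)
    cand = to_numpy(candidate).astype(np.float64)
    if ref.shape != cand.shape:
        raise ValueError(f"shape mismatch: {ref.shape} vs {cand.shape}")
    max_abs_diff = float(np.max(np.abs(ref - cand))) if ref.size else 0.0
    return bool(np.allclose(ref, cand, rtol=rtol, atol=atol)), max_abs_diff


def mean_latency_ms(fn, *args, warmup=10, iters=100, sync=None):
    """Average wall-clock latency of fn(*args) in milliseconds.

    `sync` is an optional zero-argument callable (e.g. torch.cuda.synchronize)
    invoked before each clock read so queued GPU work is included in the timing.
    """
    for _ in range(warmup):  # warm-up runs are excluded from the measurement
        fn(*args)
    if sync is not None:
        sync()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if sync is not None:
        sync()
    return (time.perf_counter() - start) * 1000.0 / iters
```

For example, assuming `model(sample, timestep, cond)` returns a single tensor with the same shape as `result`, one could call `outputs_match(model(sample, timestep, cond), result)` and `mean_latency_ms(trt_model.forward, (sample, timestep, cond), sync=torch.cuda.synchronize)`. Passing `sync` matters on a GPU because CUDA kernel launches are asynchronous, so without it the timer can stop before the work has finished. Reduced-precision engines (e.g. FP16) will typically need looser tolerances.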