https://github.com/phuijse/pytorchfromjava
- Host: GitHub
- URL: https://github.com/phuijse/pytorchfromjava
- Owner: phuijse
- Created: 2024-01-21T13:31:55.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-10-02T09:12:22.000Z (about 1 year ago)
- Last Synced: 2025-01-15T13:05:25.272Z (9 months ago)
- Language: C++
- Size: 11.4 MB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
# Calling PyTorch models from Java
## Contents
- `JavaTorch.java`: A Java class that calls a function in `JavaTorch.cpp` through the Java Native Interface (JNI)
- `JavaTorch.cpp`: A JNI-compatible function that loads a previously traced PyTorch model, performs inference, and returns the predicted class (a sketch follows this list)
- `model.py`: A script that can generate or evaluate a simple PyTorch model
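
As a rough illustration of what such a bridge looks like on the C++ side, here is a minimal sketch of a JNI entry point that loads a traced model and returns the predicted class. The method name `predict`, its `int(String)` signature, and the 1x10 dummy input are assumptions made for this example, not necessarily what `JavaTorch.java`/`JavaTorch.cpp` actually use:

```cpp
#include <jni.h>
#include <torch/script.h>

#include <vector>

// Hypothetical JNI entry point for a static native method `int predict(String)`
// declared in class JavaTorch; the exported name must follow the
// Java_<class>_<method> convention.
extern "C" JNIEXPORT jint JNICALL
Java_JavaTorch_predict(JNIEnv *env, jclass /*cls*/, jstring jmodel_path) {
  const char *path = env->GetStringUTFChars(jmodel_path, nullptr);
  torch::jit::script::Module module;
  try {
    module = torch::jit::load(path);  // deserialize the traced model
  } catch (const c10::Error &) {
    env->ReleaseStringUTFChars(jmodel_path, path);
    return -1;  // signal the load failure to the Java caller
  }
  env->ReleaseStringUTFChars(jmodel_path, path);

  // Dummy input; the shape must match whatever model.py traced (1x10 assumed here)
  std::vector<torch::jit::IValue> inputs;
  inputs.push_back(torch::ones({1, 10}));

  at::Tensor logits = module.forward(inputs).toTensor();
  return static_cast<jint>(logits.argmax(1).item<int64_t>());  // predicted class index
}
```

On the Java side, the matching declaration would be something like `static native int predict(String modelPath);`, with the shared library loaded via `System.loadLibrary`.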

## Running the example

Download and unzip libtorch 2.1.2, e.g. the CPU version:
```bash
wget https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip
unzip libtorch-cxx11-abi-shared-with-deps-2.1.2%2Bcpu.zip
```

Compile with:

```bash
make
```
This will:
- Generate the `JavaTorch.h` header based on the class defined in `JavaTorch.java` (a sketch of its typical shape follows this list)
- Compile `JavaTorch.java` into `JavaTorch.class`
- Compile `JavaTorch.cpp` and link it with TorchScript
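
For reference, the header that `javac -h` generates for a native method typically has this shape (again assuming a static `int predict(String)` method in class `JavaTorch`, which is not confirmed by the repository contents shown here):

```cpp
/* DO NOT EDIT THIS FILE - it is machine generated */
#include <jni.h>

#ifdef __cplusplus
extern "C" {
#endif

/*
 * Class:     JavaTorch
 * Method:    predict
 * Signature: (Ljava/lang/String;)I
 */
JNIEXPORT jint JNICALL Java_JavaTorch_predict(JNIEnv *, jclass, jstring);

#ifdef __cplusplus
}
#endif
```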

Run:

```bash
LD_LIBRARY_PATH=libtorch/lib java -Djava.library.path=. JavaTorch models/traced_dummy.pt
```
## (Optional) Creating/Running the model in Python
Install the Python torch library, e.g. with conda/mamba:

```bash
conda create -n torch pytorch=2.1.2 cpuonly numpy
```
Evaluate a model with the same dummy input used in `JavaTorch.cpp`:

```bash
python model.py --mode run --model_path models/traced_dummy.pt
```
Or recreate the PyTorch model and trace it:

```bash
python model.py --mode create --model_path models/traced_dummy.pt
```
## TODO/Improvements
- Use CMake instead of a custom Makefile; torch provides CMake recipes. [JNI is also supported by CMake](https://cmake.org/cmake/help/latest/module/FindJNI.html)
- Generate the dummy data in Java and pass it to the C function. Also, pass the logits to Java instead of only the class
- Trace the preprocessing routines
- Load the model once instead of every time inference is performed (see the sketch after this list). Profile inference.
- Link against libraries in the torch environment instead of additionally downloading `libtorch`
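
For the "load the model once" item, one possible approach (a sketch assuming a single model path per process; the function name is illustrative) is to cache the deserialized module in a function-local static, which C++11 initializes exactly once and in a thread-safe way:

```cpp
#include <torch/script.h>

#include <string>

// Deserialize the TorchScript module on the first call and reuse it afterwards.
// Note: later calls ignore `path` and return the module loaded first.
torch::jit::script::Module &cached_module(const std::string &path) {
  static torch::jit::script::Module module = torch::jit::load(path);
  return module;
}
```

If several models must be served from the same process, a mutex-guarded map keyed by path would be needed instead.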

## Notes

Tested with:
- openjdk 21.0.2 2024-01-16
- g++ (GCC) 13.2.1 20230801
- libtorch (TorchScript) 2.1.2, CPU version