https://github.com/slmt/learn-tch-rs
My playground for learning tch-rs (PyTorch in Rust).
- Host: GitHub
- URL: https://github.com/slmt/learn-tch-rs
- Owner: SLMT
- Created: 2022-06-28T14:13:07.000Z (almost 4 years ago)
- Default Branch: master
- Last Pushed: 2022-07-10T15:39:13.000Z (almost 4 years ago)
- Last Synced: 2025-08-25T11:59:41.694Z (9 months ago)
- Language: Rust
- Size: 18.6 KB
- Stars: 5
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
# My [tch-rs](https://github.com/LaurentMazare/tch-rs) Learning Playground
Since I am currently working on a big project that needs reinforcement learning in Rust, I created this repository to learn how to use [tch-rs](https://github.com/LaurentMazare/tch-rs), a Rust wrapper that provides an interface to the C++ PyTorch library (libtorch).
## Set Up the Environment
Follow the setup guide in the README of [tch-rs](https://github.com/LaurentMazare/tch-rs), then run `cargo run --example basics` in the `tch-rs` repository to check that everything works.
Note: the official PyTorch package already ships with a built-in CUDA runtime, so there is no need to install the CUDA Toolkit separately.
### OpenAI Gym
To test reinforcement learning methods, we use [OpenAI Gym](https://github.com/openai/gym) to set up environments, so OpenAI Gym must be installed on the host machine.
The host machine needs a Python runtime. Install OpenAI Gym with the following command:
```
> pip install gym
```
To check that OpenAI Gym runs properly, run the following command:
```
> cargo run --bin gym_test
```
## Set Up the Data Sets
### MNIST
1. Download the data set from [the website](http://yann.lecun.com/exdb/mnist/)
2. Decompress all the downloaded `.gz` files
3. Make sure the data files are named exactly as follows:
- `train-images-idx3-ubyte`
- `train-labels-idx1-ubyte`
- `t10k-images-idx3-ubyte`
- `t10k-labels-idx1-ubyte`
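A misnamed or still-compressed file is the most common way this step goes wrong. The following is a minimal, stdlib-only Rust sketch (not part of this repository) that reads the header of each expected file. It assumes the IDX layout described on the MNIST website: two zero bytes, a type code (`0x08` for unsigned bytes), the number of dimensions, then one big-endian 32-bit size per dimension. The names `parse_idx_header` and `check_mnist_dir` are hypothetical helpers.

```rust
use std::fs::File;
use std::io::{self, Read};
use std::path::Path;

/// File names expected by the MNIST binaries (from the list above).
const MNIST_FILES: [&str; 4] = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
];

/// Parses an IDX header and returns the size of each dimension.
/// Assumed layout: two zero bytes, a type code (0x08 = unsigned byte),
/// the number of dimensions, then one big-endian u32 per dimension.
fn parse_idx_header(bytes: &[u8]) -> Result<Vec<u32>, String> {
    if bytes.len() < 4 || bytes[0] != 0 || bytes[1] != 0 {
        // A file that is still gzipped starts with 0x1f 0x8b and is rejected here.
        return Err("not an IDX file (still compressed?)".to_string());
    }
    if bytes[2] != 0x08 {
        return Err(format!("unexpected type code 0x{:02x}", bytes[2]));
    }
    let ndim = bytes[3] as usize;
    if bytes.len() < 4 + 4 * ndim {
        return Err("truncated header".to_string());
    }
    let dims: Vec<u32> = (0..ndim)
        .map(|i| {
            let o = 4 + 4 * i;
            u32::from_be_bytes([bytes[o], bytes[o + 1], bytes[o + 2], bytes[o + 3]])
        })
        .collect();
    Ok(dims)
}

/// Opens each expected file in `dir`, reads at most 16 header bytes
/// (enough for up to three dimensions), and returns the parsed sizes.
fn check_mnist_dir(dir: &Path) -> io::Result<Vec<(String, Vec<u32>)>> {
    let mut report = Vec::new();
    for &name in MNIST_FILES.iter() {
        let mut header = Vec::new();
        File::open(dir.join(name))?.take(16).read_to_end(&mut header)?;
        let dims = parse_idx_header(&header)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, format!("{}: {}", name, e)))?;
        report.push((name.to_string(), dims));
    }
    Ok(report)
}
```

With correctly prepared files, the training images should report `[60000, 28, 28]` and the training labels `[60000]`; a missing file surfaces as a "not found" I/O error naming the path.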
## Available Binaries
Each of the following binaries can be run with `cargo run --bin [NAME]`.
- `mnist_dense`: example of training a dense (fully connected) network on the MNIST data set
- `mnist_cnn`: example of training a CNN on the MNIST data set
- `gym_test`: a test that checks whether [OpenAI Gym](https://github.com/openai/gym) can be run properly
- `cartpole_policy_gradient`: example of training a policy gradient RL agent in the [Cart Pole](https://www.gymlibrary.ml/environments/classic_control/cart_pole/) environment
- `cartpole_a2c`: example of training an advantage actor-critic (A2C) RL agent in the [Cart Pole](https://www.gymlibrary.ml/environments/classic_control/cart_pole/) environment
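Both Cart Pole agents learn from per-step returns. As a generic illustration (not taken from this repository's source), the sketch below computes discounted returns `G_t = r_t + gamma * G_{t+1}` for one finished episode, as used by a policy gradient (REINFORCE-style) agent, and the simple advantage estimate `A_t = G_t - V(s_t)` that an actor-critic agent can use, where `V(s_t)` is the critic's prediction. The function names are hypothetical, and the repository's binaries may use a different estimator (for example n-step or TD-based advantages).

```rust
/// Discounted return for every step of one finished episode:
/// G_t = r_t + gamma * G_{t+1}, computed backwards from the last step.
fn discounted_returns(rewards: &[f64], gamma: f64) -> Vec<f64> {
    let mut returns = vec![0.0; rewards.len()];
    let mut running = 0.0;
    for t in (0..rewards.len()).rev() {
        running = rewards[t] + gamma * running;
        returns[t] = running;
    }
    returns
}

/// Advantage estimate A_t = G_t - V(s_t), where `values` holds a critic's
/// predictions V(s_t) for the same steps. Using the Monte Carlo return minus
/// a value baseline is one simple choice among several.
fn advantages(returns: &[f64], values: &[f64]) -> Vec<f64> {
    assert_eq!(returns.len(), values.len(), "one value per step expected");
    returns.iter().zip(values.iter()).map(|(g, v)| g - v).collect()
}
```

Because Cart Pole gives a reward of 1 per step, `gamma = 1.0` makes `G_t` equal to the number of remaining steps, while `gamma < 1.0` weights near-term steps more heavily. In a tch-rs agent these values would typically be turned into a tensor and multiplied with the log-probabilities of the chosen actions to form the loss; that part is omitted here.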