https://github.com/TristanBilot/mlx-GCN
MLX implementation of GCN, with benchmark on MPS, CUDA and CPU (M1 Pro, M2 Ultra, M3 Max).
- Host: GitHub
- URL: https://github.com/TristanBilot/mlx-GCN
- Owner: TristanBilot
- License: mit
- Created: 2023-12-11T09:40:09.000Z (11 months ago)
- Default Branch: main
- Last Pushed: 2023-12-16T22:19:53.000Z (11 months ago)
- Last Synced: 2024-08-16T23:25:08.678Z (3 months ago)
- Topics: apple, cuda, deep-learning, gnn, mlx, pytorch
- Language: Python
- Size: 188 KB
- Stars: 19
- Watchers: 4
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Graph Convolutional Network in MLX
An example implementation of a [GCN](https://arxiv.org/pdf/1609.02907.pdf) in MLX. Other examples are available here.
The benchmark on **M1 Pro**, **M2 Ultra** and **M3 Max** chips and **Tesla V100** GPUs is explained in this Medium article.
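The core of a GCN (Kipf & Welling) is the propagation rule `H' = σ(D̂^(-1/2) (A + I) D̂^(-1/2) H W)`, where `A` is the adjacency matrix, `D̂` is the degree matrix of `A + I`, `H` holds the node features and `W` is a learned weight matrix. As a rough illustration, here is a minimal NumPy sketch of that rule and of a two-layer forward pass. It is not the MLX code from `main.py`; it uses a dense adjacency matrix for readability, and the function names (`normalize_adjacency`, `gcn_layer`, `gcn_forward`) and toy sizes are hypothetical.

```python
import numpy as np


def normalize_adjacency(adj: np.ndarray) -> np.ndarray:
    """Return D^-1/2 (A + I) D^-1/2, the renormalized adjacency of a GCN."""
    a_hat = adj + np.eye(adj.shape[0])       # add self-loops
    deg = a_hat.sum(axis=1)                  # node degrees (>= 1 thanks to self-loops)
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    # Scale row i by d_i^-1/2 and column j by d_j^-1/2
    return d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]


def relu(x: np.ndarray) -> np.ndarray:
    return np.maximum(x, 0.0)


def gcn_layer(a_norm: np.ndarray, h: np.ndarray, w: np.ndarray, activation=None) -> np.ndarray:
    """One graph convolution: aggregate neighbour features, then project with W."""
    out = a_norm @ h @ w
    return activation(out) if activation is not None else out


def gcn_forward(adj: np.ndarray, x: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Two-layer GCN: logits = A_norm @ relu(A_norm @ X @ W1) @ W2."""
    a_norm = normalize_adjacency(adj)
    h = gcn_layer(a_norm, x, w1, relu)
    return gcn_layer(a_norm, h, w2)       # raw class scores; softmax goes in the loss


if __name__ == "__main__":
    # Toy undirected path graph 0-1-2-3 (symmetric adjacency)
    adj = np.array([[0, 1, 0, 0],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [0, 0, 1, 0]], dtype=float)
    rng = np.random.default_rng(0)
    x = rng.normal(size=(4, 3))    # 3 input features per node
    w1 = rng.normal(size=(3, 8))   # hidden size 8
    w2 = rng.normal(size=(8, 2))   # 2 output classes
    logits = gcn_forward(adj, x, w1, w2)
    print(logits.shape)  # (4, 2)
```

On a real dataset such as Cora the adjacency matrix is large and very sparse, so an actual implementation would typically store it in a sparse format rather than as a dense array.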
### Install env and requirements
```
CONDA_SUBDIR=osx-arm64 conda create -n mlx python=3.10 numpy pytorch scipy requests -c conda-forge
conda activate mlx
pip install mlx
```

### Run
To try the model, run `main.py`. This downloads the Cora dataset, then trains and tests the model. The MLX code is in `main.py`, whereas the PyTorch equivalent is in `main_torch.py`.
```
python main.py
```

### Run benchmark
To run the benchmark on a CUDA device, set up a new env without the `CONDA_SUBDIR=osx-arm64` prefix, so that conda installs packages for the machine's native x86 architecture rather than for arm64. For all other experiments on Apple Silicon, reuse the env created above.
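Timing these backends fairly needs some care: CUDA and MPS kernels run asynchronously, and MLX evaluates lazily, so the work must be forced to finish before the clock is read. The helper below is a hypothetical, framework-agnostic sketch of such a measurement (`time_fn` and its parameters are illustrative and not taken from `benchmark.py`): it runs a few warm-up calls, then averages wall-clock time over several timed calls, invoking an optional `sync` hook after each one.

```python
import time
from statistics import mean
from typing import Callable, Optional


def time_fn(fn: Callable[[], object],
            sync: Optional[Callable[[], None]] = None,
            warmup: int = 3,
            iters: int = 10) -> float:
    """Return the mean wall-clock time of fn() in seconds.

    sync (hypothetical hook) is called after every fn() so asynchronous
    backends finish their queued work before the timer stops, e.g.
    torch.cuda.synchronize on CUDA or torch.mps.synchronize on MPS.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")
    # Warm-up runs absorb one-off costs (kernel compilation, caching, allocation)
    for _ in range(warmup):
        fn()
        if sync is not None:
            sync()
    timings = []
    for _ in range(iters):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        timings.append(time.perf_counter() - start)
    return mean(timings)


if __name__ == "__main__":
    avg = time_fn(lambda: sum(range(100_000)), iters=5)
    print(f"mean: {avg * 1e3:.3f} ms")
```

With PyTorch on CUDA, a hypothetical training step `step` could be timed as `time_fn(step, sync=torch.cuda.synchronize)`. With MLX, the simplest option is to have `step` itself call `mx.eval(...)` on its outputs, since otherwise only the construction of the lazy graph would be measured.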
```
python benchmark.py --experiment=[ mlx | torch_mps | torch_cpu | torch_cuda ]
```

### Process benchmark figure
Generating the figure requires two additional packages: `matplotlib` and `scikit-learn`.
```
pip install matplotlib scikit-learn
python viz.py
```