Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/lommix/snail_nn
Neural network and matrix library with parallelized learning, built from the ground up. Educational project.
- Host: GitHub
- URL: https://github.com/lommix/snail_nn
- Owner: Lommix
- License: apache-2.0
- Created: 2023-07-13T09:51:40.000Z (over 1 year ago)
- Default Branch: master
- Last Pushed: 2023-10-17T22:03:41.000Z (about 1 year ago)
- Last Synced: 2024-09-19T03:14:55.161Z (4 months ago)
- Topics: ai, calculus, gradient-descent, machine-learning, matrix-library, neural-network, parallelism
- Language: Rust
- Size: 618 KB
- Stars: 3
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# [WIP] Snail NN - smol neural network library
Minimalistic CPU-based neural network library with backpropagation and parallelized stochastic gradient descent.
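For reference, each stochastic gradient descent step moves the weights against the gradient of the loss on a random minibatch, scaled by a learning rate $\eta$:

$$
w \leftarrow w - \eta \, \nabla_w \mathcal{L}_{\text{minibatch}}(w)
$$

This appears to be the update applied by the `nn.learn(...)` call in the example code below, with `rate` playing the role of $\eta$.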
## Examples
Storing images inside the neural network, then upscaling and interpolating between them.
```bash
cargo run --example imagepol --release
```

![image](docs/example_interpolation.png)
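One plausible reading of "storing images inside the neural network" is a coordinate network: a small MLP is trained to map normalized pixel coordinates to color values, so the image lives in the weights and can be sampled at any resolution or interpolated between models. The sketch below only reuses the API from the example code further down; the coordinate-to-RGB framing, layer sizes, and hyperparameters are assumptions, not taken from the crate:

```rust
use snail_nn::prelude::*;

fn main() {
    // Sketch: encode an image as a function (x, y) -> (r, g, b).
    // Layer sizes and the coordinate->color framing are assumptions.
    let mut nn = Model::new(&[2, 16, 16, 3]);
    nn.set_activation(Activation::Sigmoid);

    // 2 inputs (x, y) and 3 outputs (r, g, b), all normalized to 0..1.
    let mut batch = TrainingBatch::empty(2, 3);
    // A real program would add one entry per pixel of the source image.
    batch.add(&[0.0, 0.0], &[1.0, 0.0, 0.0]); // top-left pixel: red
    batch.add(&[1.0, 1.0], &[0.0, 0.0, 1.0]); // bottom-right pixel: blue

    for _ in 0..10_000 {
        let (w_gradient, b_gradient) = nn.gradient(&batch.random_chunk(2));
        nn.learn(w_gradient, b_gradient, 1.0);
    }

    // Upscaling is just sampling the network on a denser coordinate grid.
    println!("color at image center: {:?}", nn.forward(&[0.5, 0.5]));
}
```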
---
The mandatory xor example
```bash
cargo run --example xor --release
```

![image](docs/xor.png)
---
Example Code:
```rust
use snail_nn::prelude::*;

fn main() {
    let mut nn = Model::new(&[2, 3, 1]);
    nn.set_activation(Activation::Sigmoid);

    let mut batch = TrainingBatch::empty(2, 1);
    let rate = 1.0;

    // AND - training data
    batch.add(&[0.0, 0.0], &[0.0]);
    batch.add(&[1.0, 0.0], &[0.0]);
    batch.add(&[0.0, 1.0], &[0.0]);
    batch.add(&[1.0, 1.0], &[1.0]);

    for _ in 0..10000 {
        let (w_gradient, b_gradient) = nn.gradient(&batch.random_chunk(2));
        nn.learn(w_gradient, b_gradient, rate);
    }

    println!("output {:?} expected: 0.0", nn.forward(&[0.0, 0.0]));
    println!("output {:?} expected: 0.0", nn.forward(&[1.0, 0.0]));
    println!("output {:?} expected: 0.0", nn.forward(&[0.0, 1.0]));
    println!("output {:?} expected: 1.0", nn.forward(&[1.0, 1.0]));
}
```
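The bundled xor example presumably differs from the snippet above only in its truth table and training budget; a minimal sketch under that assumption (the crate's actual example may differ):

```rust
use snail_nn::prelude::*;

fn main() {
    // XOR is not linearly separable, so the hidden layer is essential.
    let mut nn = Model::new(&[2, 3, 1]);
    nn.set_activation(Activation::Sigmoid);

    let mut batch = TrainingBatch::empty(2, 1);
    batch.add(&[0.0, 0.0], &[0.0]);
    batch.add(&[1.0, 0.0], &[1.0]);
    batch.add(&[0.0, 1.0], &[1.0]);
    batch.add(&[1.0, 1.0], &[0.0]);

    for _ in 0..20_000 {
        let (w_gradient, b_gradient) = nn.gradient(&batch.random_chunk(2));
        nn.learn(w_gradient, b_gradient, 1.0);
    }

    println!("0 xor 0 -> {:?} expected: 0.0", nn.forward(&[0.0, 0.0]));
    println!("1 xor 0 -> {:?} expected: 1.0", nn.forward(&[1.0, 0.0]));
    println!("0 xor 1 -> {:?} expected: 1.0", nn.forward(&[0.0, 1.0]));
    println!("1 xor 1 -> {:?} expected: 0.0", nn.forward(&[1.0, 1.0]));
}
```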
## Features

- Sigmoid, Tanh & Relu activation functions
- Parallelized stochastic gradient descent

## Todos
- Wgpu compute shaders