https://github.com/FrozenAssassine/NeuralNetwork-Arduino
A lightweight neural network library for ESP32 and Arduino, enabling on-device training and simple predictions like XOR. Ideal for microcontroller-based AI projects with resource constraints. 🔧🚀
- Host: GitHub
- URL: https://github.com/FrozenAssassine/NeuralNetwork-Arduino
- Owner: FrozenAssassine
- Created: 2024-09-22T15:51:46.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2025-03-05T10:26:25.000Z (about 1 year ago)
- Last Synced: 2025-03-22T21:04:47.386Z (about 1 year ago)
- Topics: ai, arduino, cpp, esp32, local, neural-network, neural-networks, xor
- Language: C++
- Homepage: https://medium.com/@FrozenAssassine/neural-network-from-scratch-on-esp32-2a53a7b65f9f
- Size: 336 KB
- Stars: 7
- Watchers: 2
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
# Neural Network for ESP32 and Arduino
## 🤔 What is this project?
This project is a lightweight neural network implementation designed to run on microcontrollers such as the **ESP32** and **Arduino**. It demonstrates that even resource-constrained devices can train a network and perform simple tasks like **XOR** prediction, which may be useful for simple robot projects.
Training takes only a few **seconds** on the ESP32, while the Arduino needs significantly more time due to its limited processing power.
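To give a feel for what on-device training involves, here is a minimal, self-contained sketch of a 2-4-1 network (TanH hidden layer, Sigmoid output, matching the layer setup in the example further down) trained on XOR with plain stochastic gradient descent. It is an independent illustration, not this library's implementation; the namespace `xor_sketch`, the helper names, the fixed-seed random generator, and the squared-error loss are all assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Standalone XOR sketch (hypothetical names, NOT the library's code).
namespace xor_sketch {

constexpr int kIn = 2, kHidden = 4;

struct Net {
    float w1[kHidden][kIn], b1[kHidden];  // input -> hidden
    float w2[kHidden], b2;                // hidden -> output
};

// Tiny deterministic LCG so runs are reproducible; returns a value in [-1, 1).
inline float NextRand(std::uint32_t &state) {
    state = state * 1664525u + 1013904223u;
    return (state >> 8) / 16777216.0f * 2.0f - 1.0f;
}

inline void Init(Net &n, std::uint32_t seed) {
    for (int h = 0; h < kHidden; ++h) {
        for (int i = 0; i < kIn; ++i) n.w1[h][i] = NextRand(seed);
        n.b1[h] = 0.0f;
        n.w2[h] = NextRand(seed);
    }
    n.b2 = 0.0f;
}

inline float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Forward pass: fills the hidden activations and returns the Sigmoid output.
inline float Forward(const Net &n, const float *x, float *hidden) {
    float out = n.b2;
    for (int h = 0; h < kHidden; ++h) {
        float z = n.b1[h];
        for (int i = 0; i < kIn; ++i) z += n.w1[h][i] * x[i];
        hidden[h] = std::tanh(z);
        out += n.w2[h] * hidden[h];
    }
    return Sigmoid(out);
}

// One SGD step on a single sample, loss = 0.5 * (y - target)^2.
inline void TrainStep(Net &n, const float *x, float target, float lr) {
    float hidden[kHidden];
    const float y = Forward(n, x, hidden);
    const float dOut = (y - target) * y * (1.0f - y);  // dLoss / d(pre-sigmoid)
    for (int h = 0; h < kHidden; ++h) {
        // Backpropagate through w2 (before updating it) and tanh' = 1 - tanh^2.
        const float dHidden = dOut * n.w2[h] * (1.0f - hidden[h] * hidden[h]);
        n.w2[h] -= lr * dOut * hidden[h];
        for (int i = 0; i < kIn; ++i) n.w1[h][i] -= lr * dHidden * x[i];
        n.b1[h] -= lr * dHidden;
    }
    n.b2 -= lr * dOut;
}

// Runs `epochs` passes over the four XOR samples.
inline void Train(Net &n, const float inputs[4][2], const float targets[4],
                  int epochs, float lr) {
    assert(epochs > 0 && lr > 0.0f);
    for (int e = 0; e < epochs; ++e)
        for (int s = 0; s < 4; ++s) TrainStep(n, inputs[s], targets[s], lr);
}

// Mean squared error over the four XOR samples.
inline float Loss(const Net &n, const float inputs[4][2], const float targets[4]) {
    float sum = 0.0f, hidden[kHidden];
    for (int s = 0; s < 4; ++s) {
        const float d = Forward(n, inputs[s], hidden) - targets[s];
        sum += d * d;
    }
    return sum / 4.0f;
}

}  // namespace xor_sketch
```

Typical use: call `Init(net, 42u)`, then `Train(net, inputs, targets, 5000, 0.5f)`, and check that `Loss` has dropped. Keeping everything in fixed-size arrays avoids heap allocation, which matters on boards with only a few kilobytes of RAM.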
## 📎 [Blog to this project](https://medium.com/@FrozenAssassine/neural-network-from-scratch-on-esp32-2a53a7b65f9f)
## 🛠️ Features
- **On-device training**: Train your neural network directly on ESP32 or Arduino.
- **XOR**: Learn and predict simple functions such as XOR.
- **Activation Functions**: Sigmoid, ReLU, Softmax, TanH, and LeakyReLU.
- **Fast Training**: The ESP32 can train in just a few seconds, while the Arduino requires longer due to its slow processor.
- **Xavier Initialization**: Scales the initial weights to each layer's input and output size, which helps training converge faster.
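The activation functions and the Xavier initialization listed above can be sketched in plain C++ as follows. These are the textbook formulas, not the library's source: the function names, the LeakyReLU slope of 0.01, and the choice of the uniform Xavier variant are assumptions. The code uses `<random>`, which is typically unavailable on classic AVR Arduino boards, so treat it as a desktop illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>

// Element-wise activations (textbook definitions, hypothetical names).
namespace act {

inline float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
inline float TanH(float x) { return std::tanh(x); }
inline float Relu(float x) { return x > 0.0f ? x : 0.0f; }
// Slope for negative inputs; 0.01 is an assumed default.
inline float LeakyRelu(float x, float slope = 0.01f) { return x > 0.0f ? x : slope * x; }

// Softmax over a whole layer: out[i] = exp(in[i] - max) / sum_j exp(in[j] - max).
// Subtracting the maximum keeps exp() from overflowing.
inline void Softmax(const float *in, float *out, std::size_t n) {
    assert(n > 0);
    float maxVal = in[0];
    for (std::size_t i = 1; i < n; ++i)
        if (in[i] > maxVal) maxVal = in[i];
    float sum = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        out[i] = std::exp(in[i] - maxVal);
        sum += out[i];
    }
    for (std::size_t i = 0; i < n; ++i) out[i] /= sum;
}

}  // namespace act

// Xavier (Glorot) uniform initialization: each of the fanIn * fanOut weights
// is drawn from U(-limit, limit) with limit = sqrt(6 / (fanIn + fanOut)),
// which keeps activation variance roughly constant from layer to layer.
inline void XavierInit(float *weights, std::size_t fanIn, std::size_t fanOut,
                       std::mt19937 &rng) {
    assert(fanIn + fanOut > 0);
    const float limit = std::sqrt(6.0f / static_cast<float>(fanIn + fanOut));
    std::uniform_real_distribution<float> dist(-limit, limit);
    for (std::size_t i = 0; i < fanIn * fanOut; ++i) weights[i] = dist(rng);
}
```

Softmax is the odd one out: it normalizes a whole layer so the outputs sum to 1, which is why it takes an array instead of a single value.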
## 🔮 Future features
- Train on PC and load weights to chip
- Save and load weights
- More layer types
## 🚀 Performance
- ESP32: fast training (a few seconds).
- Arduino: slower training (minutes or more).
## 🫶 Code considerations
I tried to keep the code as simple and easy to understand as possible. The neural network is built entirely on OOP principles: every component is its own class, which makes it easy to structure the model.
For the individual layers, I used inheritance: each layer derives from a `BaseLayer` class. `BaseLayer` defines the shared functions, such as `Train` and `FeedForward`, and holds pointers to the weights, values, biases, and errors. Each derived class only overrides these functions with its own training logic and variable setup, which makes adding new layers straightforward.
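A minimal sketch of that inheritance pattern, assuming a pure virtual `FeedForward`/`Train` interface. The class names (`SketchInputLayer`, `SketchDenseLayer`), the ReLU activation, the constant starting weights, and the use of `std::vector` instead of raw pointers are illustrative choices, not the library's actual headers. Computing the `errors` values (backpropagation) is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical base class: shared state plus the functions every layer overrides.
class BaseLayer {
public:
    explicit BaseLayer(std::size_t size) : values(size, 0.0f), errors(size, 0.0f) {
        assert(size > 0);
    }
    virtual ~BaseLayer() = default;

    virtual void FeedForward(const BaseLayer &previous) = 0;
    virtual void Train(const BaseLayer &previous, float learningRate) = 0;

    std::vector<float> values;  // outputs of this layer's neurons
    std::vector<float> errors;  // dLoss/d(pre-activation), filled by backprop (not shown)
};

// Input layer: only holds the sample, so both overrides do nothing.
class SketchInputLayer : public BaseLayer {
public:
    using BaseLayer::BaseLayer;
    void FeedForward(const BaseLayer &) override {}
    void Train(const BaseLayer &, float) override {}
};

// Fully connected layer with a ReLU activation.
class SketchDenseLayer : public BaseLayer {
public:
    SketchDenseLayer(std::size_t size, std::size_t inputSize)
        : BaseLayer(size), weights(size * inputSize, 0.1f), biases(size, 0.0f),
          inputSize_(inputSize) {}

    void FeedForward(const BaseLayer &previous) override {
        assert(previous.values.size() == inputSize_);
        for (std::size_t o = 0; o < values.size(); ++o) {
            float z = biases[o];
            for (std::size_t i = 0; i < inputSize_; ++i)
                z += weights[o * inputSize_ + i] * previous.values[i];
            values[o] = z > 0.0f ? z : 0.0f;  // ReLU
        }
    }

    // Gradient step, assuming `errors` already holds dLoss/dz for each neuron.
    void Train(const BaseLayer &previous, float learningRate) override {
        for (std::size_t o = 0; o < values.size(); ++o) {
            for (std::size_t i = 0; i < inputSize_; ++i)
                weights[o * inputSize_ + i] -= learningRate * errors[o] * previous.values[i];
            biases[o] -= learningRate * errors[o];
        }
    }

    std::vector<float> weights;  // row-major: weights[o * inputSize + i]
    std::vector<float> biases;

private:
    std::size_t inputSize_;
};
```

Because the network only talks to `BaseLayer` references, a loop such as `layers[l]->FeedForward(*layers[l - 1])` works for any layer type, so adding a new layer means writing one more subclass.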
## 🏗️ How to Use
1. Clone this repository and open the project in the Arduino IDE.
2. Upload the code to your ESP32 or Arduino from the Arduino IDE.
3. Open the Serial Monitor at 115200 baud to see the predictions.
Here is an example that trains the network on XOR:
```cpp
#include "Layers.h"
#include "NeuralNetwork.h"

void setup() {
  Serial.begin(115200);

  // 3 layers: 2 inputs -> 4 hidden neurons (TanH) -> 1 output (Sigmoid)
  NeuralNetwork *nn = new NeuralNetwork(3);
  nn->StackLayer(new InputLayer(2));
  nn->StackLayer(new DenseLayer(4, ActivationKind::TanH));
  nn->StackLayer(new OutputLayer(1, ActivationKind::Sigmoid));
  nn->Build();

  // XOR truth table: input pairs and their expected outputs
  float inputs[4][2] = { { 0, 0 }, { 0, 1 }, { 1, 0 }, { 1, 1 } };
  float desired[4][1] = { { 0 }, { 1 }, { 1 }, { 0 } };

  // Train on the 4 XOR samples (600 epochs, learning rate 0.1f)
  nn->Train((float*)inputs, (float*)desired, 4, 2, 600, 0.1f);

  // Predict XOR results and print them to the Serial Monitor
  for (int i = 0; i < 4; i++) {
    float *pred = nn->Predict(inputs[i], 2);
    Serial.print("PREDICTION ");
    Serial.print(inputs[i][0]);
    Serial.print(" ");
    Serial.print(inputs[i][1]);
    Serial.print(" = ");
    Serial.println(pred[0]);
  }
}

void loop() {
  // Nothing to do here: training and prediction run once in setup()
  delay(1000);
}
```