https://github.com/lwch/tnn
go版本机器学习框架
https://github.com/lwch/tnn
ai cnn deep-learning golang mlp neural-network rnn
Last synced: about 1 month ago
JSON representation
go版本机器学习框架
- Host: GitHub
- URL: https://github.com/lwch/tnn
- Owner: lwch
- License: mit
- Created: 2023-04-06T07:15:01.000Z (about 2 years ago)
- Default Branch: gotorch
- Last Pushed: 2024-07-28T12:57:13.000Z (9 months ago)
- Last Synced: 2024-10-12T04:44:35.953Z (7 months ago)
- Topics: ai, cnn, deep-learning, golang, mlp, neural-network, rnn
- Language: Go
- Homepage:
- Size: 36.8 MB
- Stars: 8
- Watchers: 2
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# tnn
go版本神经网络框架,支持模型训练和推理,libgotorch库的安装方式请看[libgotorch安装](https://github.com/lwch/gotorch#%E5%AE%89%E8%A3%85)
## 工具
- ~~[minfo](cmd/minfo/): 这是tnn框架中的一个工具,用于查看保存模型的定义信息~~
## 示例
- [xor](example/xor/): 该示例是一个四层的神经网络,用于训练异或运算(回归)
- [sin](example/sin/): 该示例使用rnn网络来进行sin函数的时序任务训练(回归)
- [sin_attention](example/sin_attention/): 使用transformer来实现sin曲线的预测(回归)
- [couplet](example/couplet/): 使用GPT模型来对对联(回归)## 构造网络
首先定义网络的每个层
```go
hiddenLayer := layer.NewDense(10)
relu := activation.NewReLU()
outputLayer := layer.NewDense(1)
```选定一个优化器
```go
optimizer := optimizer.NewAdam()
```定义forward函数
```go
func forward(x *tensor.Tensor) *tensor.Tensor {
y := hiddenLayer.Forward(x)
y = relu.Forward(y)
y = outputLayer.Forward(y)
return y
}
```最后构造出模型并进行模型训练
```go
for i := 0; i < 10; i++ {
loss := loss.NewMse(forward(input), output)
loss.Backward()
optimizer.Step(net.Params())
fmt.Printf("Epoch: %d, Loss: %.05f\n", i, loss.Value())
}
```模型推理方法如下
```go
pred := forward(input)
fmt.Println(pred.Float32Value())
```## 感谢
- [tinynn](https://github.com/borgwang/tinynn)
- [gonum](https://github.com/gonum/gonum)
- [pytorch](https://pytorch.org)
- [cobra](github.com/spf13/cobra)