https://github.com/rockerboo/candle_scheduler
LR Schedulers for Candle
- Host: GitHub
- URL: https://github.com/rockerboo/candle_scheduler
- Owner: rockerBOO
- Created: 2023-12-16T22:04:04.000Z (almost 2 years ago)
- Default Branch: main
- Last Pushed: 2023-12-17T21:11:01.000Z (almost 2 years ago)
- Last Synced: 2025-04-06T17:19:28.719Z (6 months ago)
- Topics: candle, rust, scheduler
- Language: Rust
- Homepage:
- Size: 16.6 KB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
# candle_scheduler
Learning rate (LR) schedulers for [Candle](https://github.com/huggingface/candle):
- OneCycle
- CosineAnnealing

## Install
Add the dependency to `Cargo.toml`:

```toml
[dependencies]
candle-scheduler = { git = "https://github.com/rockerBOO/candle_scheduler.git", rev = "a91c1c9692d8cc1da4f4e56900fae3a81eb4eb41" }
```

## Usage
```rust
use candle_core::Tensor;
use candle_nn::{loss, AdamW, Optimizer, ParamsAdamW, VarMap};
// `scheduler::OneCycleLR` comes from this crate; bring its module into scope.

let varmap = VarMap::new();
let params = ParamsAdamW {
    // The LR set here is overwritten by the scheduler
    ..Default::default()
};
let mut opt = AdamW::new(varmap.all_vars(), params)?;

// Total number of training steps
let total_steps = 10;
// The div factor used to derive the minimum learning rate (LR)
let div_factor = 25.;
// The maximum LR
let max_lr = 1e-2;
let mut scheduler = scheduler::OneCycleLR::new(max_lr, total_steps, div_factor);

// Training steps
for _ in 0..total_steps {
    // Some logits from the model
    let logits = Tensor::rand(...);
    // Calculate the loss against the targets (`targets`: expected outputs)
    let loss = loss::mse(&logits, &targets)?;
    // Backward pass and optimizer update
    opt.backward_step(&loss)?;
    // Then update the LR with the scheduler
    scheduler.step(&mut opt);
    println!("{}", scheduler.get_lr());
}
```
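The README names the two schedules but not the curves they produce. The sketch below writes both shapes as plain functions of the step index so the LR values can be inspected without Candle. It is not this crate's implementation: `cosine_anneal`, `cosine_annealing_lr`, `one_cycle_lr`, and the `pct_start` / `final_div_factor` parameters are hypothetical names, and the one-cycle shape follows the common two-phase cosine policy (as in PyTorch's `OneCycleLR`). Whether `OneCycleLR::new` here uses the same warm-up fraction or final divisor is an assumption.

```rust
use std::f64::consts::PI;

/// Cosine interpolation from `start` (frac = 0.0) to `end` (frac = 1.0).
fn cosine_anneal(start: f64, end: f64, frac: f64) -> f64 {
    end + (start - end) * (1.0 + (PI * frac).cos()) / 2.0
}

/// Cosine annealing: decays from `max_lr` at step 0 to `min_lr` at `total_steps`.
fn cosine_annealing_lr(step: usize, total_steps: usize, max_lr: f64, min_lr: f64) -> f64 {
    assert!(total_steps > 0, "total_steps must be positive");
    // Clamp so steps past the end stay at `min_lr`.
    let frac = step.min(total_steps) as f64 / total_steps as f64;
    cosine_anneal(max_lr, min_lr, frac)
}

/// One-cycle (hypothetical sketch): cosine warm-up from `max_lr / div_factor`
/// to `max_lr` over the first `pct_start` fraction of the run, then cosine
/// decay to `max_lr / div_factor / final_div_factor` at the last step.
fn one_cycle_lr(
    step: usize,
    total_steps: usize,
    max_lr: f64,
    div_factor: f64,
    pct_start: f64,        // hypothetical parameter, e.g. 0.3
    final_div_factor: f64, // hypothetical parameter, e.g. 1e4
) -> f64 {
    assert!(total_steps >= 2, "need at least two steps");
    assert!(pct_start > 0.0 && pct_start < 1.0, "pct_start must be in (0, 1)");
    let initial_lr = max_lr / div_factor;
    let min_lr = initial_lr / final_div_factor;
    let last = (total_steps - 1) as f64;
    let warmup_end = pct_start * last;
    let t = step.min(total_steps - 1) as f64;
    if t <= warmup_end {
        // Phase 1: ramp up from initial_lr to max_lr.
        cosine_anneal(initial_lr, max_lr, t / warmup_end)
    } else {
        // Phase 2: anneal down from max_lr to min_lr.
        cosine_anneal(max_lr, min_lr, (t - warmup_end) / (last - warmup_end))
    }
}

fn main() {
    // Same settings as the usage example: 10 steps, max LR 1e-2, div factor 25.
    for step in 0..10 {
        println!(
            "step {step}: one-cycle {:.6}, cosine {:.6}",
            one_cycle_lr(step, 10, 1e-2, 25.0, 0.3, 1e4),
            cosine_annealing_lr(step, 10, 1e-2, 0.0),
        );
    }
}
```

Keeping each schedule a pure function of the step makes it easy to test in isolation. To push a computed value into a Candle optimizer directly, the `candle_nn::Optimizer` trait provides `set_learning_rate(lr)`.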