Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
Fractals for Deep and Sophisticated (Training) Loops
https://github.com/tariqdaouda/abyme
- Host: GitHub
- URL: https://github.com/tariqdaouda/abyme
- Owner: tariqdaouda
- License: apache-2.0
- Created: 2019-06-28T01:07:12.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2019-09-16T19:48:46.000Z (over 5 years ago)
- Last Synced: 2024-12-28T08:28:55.137Z (28 days ago)
- Topics: deep-learning, deep-neural-networks, deeplearning
- Language: Jupyter Notebook
- Size: 117 KB
- Stars: 0
- Watchers: 3
- Forks: 0
- Open Issues: 0
- Metadata Files:
  - Readme: README.md
  - License: LICENSE
README
# Abyme (Fractals)
Abyme is a tool for writing Deep and Sophisticated (Training) Loops.
Training loops involve a lot of cuisine (housekeeping):
* When to save a model
* What to print on screen, and when
* What information to capture for debugging
* In what format to save it
* At which periodicity

With Abyme, training loops are written as fractals that go deeper and deeper, allowing the user to dynamically plug events at *user-defined* steps. It sounds complicated, but it actually makes everything much simpler.
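Before the full example, here is a minimal sketch of the nesting idea, using only constructs that appear in the example below (`AB.Ground`, `AB.Print`, the `"dig"` stage, and the closing `.dig()` call). The import line is an assumption made for illustration, not the library's documented entry point.

```python
# Minimal sketch (assumed import path): an outer "ground" stage that, when dug
# into, triggers the events plugged at its "dig" step.
import abyme as AB  # hypothetical alias; import abyme however the project's notebooks do

AB.Ground()("dig",
    AB.Print(["Entering the loop"])
).dig()
```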
```python
# torch is required; `model`, `get_data_loader` and `data_formater` are assumed
# to be defined earlier (e.g. in the notebook). AB and AP are aliases for
# abyme's builders and its PyTorch helpers, imported elsewhere.
import torch

criterion = torch.nn.modules.loss.MSELoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)

epoch_looper = AB.IterrationLooper()
train_data_looper = AB.DataLooper(get_data_loader(train=True, mask_targets=True, batch_size=500))
train_pass = AP.SupervisedPass(model, optimizer, criterion, update_parameters=True, inputs_targets_formater=data_formater)
train_stats = AB.Stats(caller_field="last_loss")

test_data_looper = AB.DataLooper(get_data_loader(train=False, mask_targets=True, batch_size=10000))
test_pass = AP.SupervisedPass(model, optimizer, criterion, update_parameters=False, inputs_targets_formater=data_formater)
test_stats = AB.Stats(caller_field="last_loss")

csv_result = AB.CSVWriter(filename="test2.csv")
# Build the stages that react to the end of an epoch: on a new loss low, print a
# message, optionally save the model, and log the running stats; also report a
# 100-step moving average periodically.
def handle_epoch_end(name, epoch_looper, data_looper, csv, save_model, stats_caller_focus):
res = (
AB.NewLowTrigger("average").focus(stats_caller_focus)("dig",
AB.Print(["==>New %s average low, epoch"%name, epoch_looper.get('counter'), "batch:", data_looper.get("counter")]),
AB.If(condition=save_model)("dig",
AP.SaveModel(model=model, filename=name, prefix=epoch_looper.get("counter")),
),
AB.PrettyPrintStore(fields=["average", "std", "min", "max"], prefix="%s.new.low." % name),
csv.add_caller_to_line(fields=["average", "std", "min", "max"], prefix="%s.new.low." % name),
),
AB.MovingStats("average", window_size=100).focus(stats_caller_focus)("dig",
AB.PeriodicTrigger(100, wait_periods=1)("dig",
AB.PrettyPrintStore(fields=["average", "std", "min", "max"], prefix="%s.loss.moving." % name),
csv.add_caller_to_line(fields=["average", "std", "min", "max"], prefix="%s.loss.moving." % name),
)
),
)
return res
AB.Ground()("dig",
epoch_looper.setup(10)("start",
AB.Print(["Training starts"])
).at("iteration_start",
csv_result.open_line(),
train_data_looper("iteration_end",
train_pass("end",
train_stats,
)
).at("end",
test_data_looper("iteration_end",
test_pass("end",
test_stats,
),
),
*handle_epoch_end("train", epoch_looper, train_data_looper, csv_result, save_model=True, stats_caller_focus=train_stats),
*handle_epoch_end("test", epoch_looper, test_data_looper, csv_result, save_model=True, stats_caller_focus=test_stats)
)
).at("iteration_end",
csv_result.commit_line(),
csv_result.save(),
test_stats.reset,
train_stats.reset
).at("end",
AB.Print("End of training")
)
).dig()
```