https://github.com/gforge/criterion_ignore
For Torch - A parallel criterion with ignore label
- Host: GitHub
- URL: https://github.com/gforge/criterion_ignore
- Owner: gforge
- License: mit
- Created: 2016-03-14T11:14:15.000Z (over 10 years ago)
- Default Branch: master
- Last Pushed: 2016-10-25T09:43:46.000Z (over 9 years ago)
- Last Synced: 2025-02-03T13:15:31.290Z (over 1 year ago)
- Language: Lua
- Size: 16.6 KB
- Stars: 0
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# The criterion_ignore addon for torch/nn
This package extends [torch/nn](https://github.com/torch/nn) with a criterion that can ignore labels.
It is a direct extension of [ParallelCriterion][1]: its `:add()` method lets you
specify an ignore label for each criterion that you add.
As of version 0.2 the arguments are also checked with `argcheck`, so a mistyped
argument triggers an automated help printout.
[1]: https://github.com/torch/nn/blob/master/doc/criterion.md#nn.ParallelCriterion
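
To make the idea concrete, here is a minimal plain-Lua sketch (no Torch required) of what a parallel criterion with per-criterion ignore labels computes. It assumes that a target equal to the ignore label simply contributes nothing to the summed loss. The names `nll` and `parallel_ignore_loss` are hypothetical and are not part of the package's API.

```lua
-- Illustrative only: a plain-Lua stand-in, not the package's implementation.

-- Negative log-likelihood of the target class, given a table of log-probabilities
local function nll(logprobs, target)
  return -logprobs[target]
end

-- entries: list of {loss = function, weight = number, ignore = label}
local function parallel_ignore_loss(entries, outputs, targets)
  local total = 0
  for i, entry in ipairs(entries) do
    -- Skip this criterion when its target equals the ignore label
    if targets[i] ~= entry.ignore then
      total = total + entry.weight * entry.loss(outputs[i], targets[i])
    end
  end
  return total
end

local entries = {
  {loss = nll, weight = 1, ignore = 0},
  {loss = nll, weight = 1, ignore = 0},
}
local outputs = {
  {math.log(0.5), math.log(0.5)},
  {math.log(0.25), math.log(0.75)},
}

local full = parallel_ignore_loss(entries, outputs, {1, 2})     -- both heads counted
local partial = parallel_ignore_loss(entries, outputs, {1, 0})  -- second head ignored
print(full, partial)
```

With log-probability inputs every term is non-negative, so ignoring a head can only lower the total in this sketch.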
## Installation
For now, the package has to be installed directly from the GitHub repository:
```bash
luarocks install https://raw.githubusercontent.com/gforge/criterion_ignore/master/rocks/criterion_ignore-0.2-1.rockspec
```
## Use case
```lua
require 'criterion_ignore'

-- Shared linear layer followed by seven parallel classification heads
model = nn.Sequential()
model:add(nn.Linear(3,5))

criterion = nn.ParallelIgnoreCriterion()
prl = nn.ConcatTable()
for i=1,7 do
  seq = nn.Sequential()
  seq:add(nn.Linear(5,i + 1))
  seq:add(nn.SoftMax())
  prl:add(seq)

  -- First parameter is weight while the second is the ignore label;
  -- argcheck also lets you specify the arguments by name, as done here
  criterion:add{
    criterion = nn.ClassNLLCriterion(),
    ignore = 0
  }
end
model:add(prl)

input = torch.rand(3)
target = {1,2,3,4,5,6,7}

-- Loss with every target present
output = model:forward(input)
print(output)
err1 = criterion:forward(output,target)
print(err1)

-- Set the fifth target to the ignore label and recompute the loss
target[5] = 0
output = model:forward(input)
print(output)
err2 = criterion:forward(output,target)
print(err2)
print(err1 < err2)
```
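
The final comparison in the example can be checked by hand. For a single sample, `nn.ClassNLLCriterion` returns `-input[target]`, and the heads above end in `nn.SoftMax`, so each head contributes a value between -1 and 0. Assuming an ignored head contributes nothing, dropping one head raises the sum, which is why `err1 < err2` holds. The plain-Lua sketch below reproduces that arithmetic with made-up probabilities; `probs` and `summed_loss` are hypothetical names used only for illustration.

```lua
-- Hypothetical SoftMax outputs for three heads, made up for illustration
local probs = {
  {0.7, 0.3},
  {0.2, 0.5, 0.3},
  {0.1, 0.2, 0.3, 0.4},
}

-- Sum of -p[target] over the heads, skipping targets equal to `ignore`
local function summed_loss(targets, ignore)
  local total = 0
  for i, p in ipairs(probs) do
    if targets[i] ~= ignore then
      total = total - p[targets[i]]
    end
  end
  return total
end

local err1 = summed_loss({1, 2, 3}, 0)  -- -(0.7 + 0.5 + 0.3)
local err2 = summed_loss({1, 2, 0}, 0)  -- -(0.7 + 0.5), third head ignored
print(err1 < err2)  -- true
```

Note that the `nn.ClassNLLCriterion` documentation expects log-probabilities (for example from `nn.LogSoftMax`). With such inputs each term would be non-negative, and under the same assumption ignoring a head would lower the total instead.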