https://github.com/masonprotter/shapecheck.jl
https://github.com/masonprotter/shapecheck.jl
Last synced: 4 months ago
JSON representation
- Host: GitHub
- URL: https://github.com/masonprotter/shapecheck.jl
- Owner: MasonProtter
- License: mit
- Created: 2022-08-02T20:34:02.000Z (over 3 years ago)
- Default Branch: master
- Last Pushed: 2022-08-04T15:15:08.000Z (over 3 years ago)
- Last Synced: 2025-01-22T03:32:56.640Z (12 months ago)
- Language: Julia
- Size: 13.7 KB
- Stars: 35
- Watchers: 3
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
[](https://github.com/MasonProtter/ShapeCheck.jl/actions/workflows/CI.yml)
# ShapeCheck.jl
This provides a nice(?) syntax for **runtime** shape checking of function outputs given their inputs. This should work for any type `T` which has methods for `size(::T, n)`. The shapes being checked here in this package can *not* be dispatched on. Inspired by https://github.com/google/jaxtyping
Consider this shapechecked implementation of a `remove_last` function:
``` julia
using ShapeCheck
const AV = AbstractVector
@shapechecked function remove_last(x :: AV[n]) :: AV[n-1]
x[1:end-1]
end
```
The use of square brackets in the type signature of the above function are essentially assertions about the shapes
of the dimensions of the inputs and outputs. It says that `x` must be an `AbstractVector`, and that if
`n = size(x, 1)`, then the output of the function must have `size(result, 1) == n - 1`.
This way, the hidden logic error in our function gets caught:
``` julia
julia> using OffsetArrays
julia> let v = OffsetVector([:a, :b, :c, :d, :e], -2:2)
remove_last(v)
end
ERROR: DimensionMismatch("Dimension 1 of result does not match n - 1 = 4, got 1.")
Stacktrace:
[1] remove_last(x::OffsetVector{Symbol, Vector{Symbol}})
@ Main [...]/ShapeCheck/src/ShapeCheck.jl:37
[2] top-level scope
@ REPL[15]:2
```
Without the shapecheck, this function would have silently returned just `[:d]`. Instead, what we should have written is
``` julia
@shapechecked function remove_last(x :: AV[n]) :: AV[n-1]
x[begin:end-1]
end
```
if we want to be able to handle general `AbstractVector`s correctly.
ShapeCheck.jl performs these checks at runtime, so they are not truly zero cost (unless your shapes are constant propagated).
### Syntax Examples
Suppose you only care about the 2nd dimension of an array, you can always just 'name' a dim `_` and it'll be thrown out.
If you don't care about a inner dim, just use `_` to ignore it:
``` julia
const AA = AbstractArray
@shapechecked function vertical_slice(x::AA[_, n], i) :: AA[n]
x[:, i]
end
```
Dimensions to the *right* of the last one listed are ignored by default.
You can demand that multiple dimensions of arguments match
``` julia
@shapechecked function my_vcat(x::AA[a, b], y::AA[c, b]) :: AA[a + c, b]
vcat(x, y)
end
```
and you can even demand that they are some function of another argument
``` julia
@shapechecked function foo(x::Vector[a], y::Vector[min(a, 3)]) :: Number
s = 0.0
for i ∈ 1:min(a, 3)
s += x[i] * y[i]
end
s
end
```