Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/tirthasheshpatel/segment_anything_keras
A multi-backend (TensorFlow, PyTorch, JAX, and NumPy) implementation of the Segment Anything model in Keras 3.0
Last synced: 3 months ago
- Host: GitHub
- URL: https://github.com/tirthasheshpatel/segment_anything_keras
- Owner: tirthasheshpatel
- License: apache-2.0
- Created: 2023-05-11T00:07:59.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-04-02T20:22:22.000Z (9 months ago)
- Last Synced: 2024-10-04T23:22:40.616Z (3 months ago)
- Language: Jupyter Notebook
- Homepage:
- Size: 53.5 MB
- Stars: 31
- Watchers: 3
- Forks: 4
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Segment Anything Model in Multi-Backend Keras
This is an implementation of the Segment Anything predictor and automatic mask
generator in Keras 3. The demos use KerasCV's Segment Anything model:
- [Predictor demo](Segment_Anything_multi_backend_Keras_Demo.ipynb)
- [Automatic Mask Generator demo](Segment_Anything_Automatic_Mask_Generator_Demo.ipynb)

## Install the package
```shell
pip install git+https://github.com/tirthasheshpatel/segment_anything_keras.git
```

Install the required dependencies:
```shell
pip install -U Pillow numpy keras keras-cv
```

Install TensorFlow, JAX, or PyTorch, whichever backend you'd like to use.
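If more than one backend is installed, the choice can also be made programmatically. The following is a minimal sketch under stated assumptions: `pick_keras_backend` and `configure_keras_backend` are hypothetical helpers (they are not part of `sam_keras`, KerasCV, or Keras), the backends are assumed to be importable as the `jax`, `tensorflow`, and `torch` packages, and the check only confirms that a package is installed, not that it works. Keras reads `KERAS_BACKEND` when it is first imported, so the variable has to be set before `import keras` or `import keras_cv`.

```python
import importlib.util
import os
from typing import Callable, Iterable

# Keras 3 backend names mapped to the Python package each one needs.
BACKEND_PACKAGES = {"jax": "jax", "tensorflow": "tensorflow", "torch": "torch"}


def _package_installed(package: str) -> bool:
    """Return True if `package` is importable, without actually importing it."""
    return importlib.util.find_spec(package) is not None


def pick_keras_backend(
    preferred: Iterable[str] = ("jax", "tensorflow", "torch"),
    is_installed: Callable[[str], bool] = _package_installed,
) -> str:
    """Hypothetical helper: return the first preferred backend that is installed."""
    for backend in preferred:
        if is_installed(BACKEND_PACKAGES[backend]):
            return backend
    raise RuntimeError("No Keras backend found; install TensorFlow, JAX, or PyTorch.")


def configure_keras_backend() -> str:
    """Hypothetical helper: set KERAS_BACKEND unless the user already did.

    Must run before the first `import keras` / `import keras_cv`.
    """
    if "KERAS_BACKEND" not in os.environ:
        os.environ["KERAS_BACKEND"] = pick_keras_backend()
    return os.environ["KERAS_BACKEND"]
```

Calling `configure_keras_backend()` at the top of a script has the same effect as the `os.environ['KERAS_BACKEND'] = ...` line in the pretrained-model snippet, except that an explicitly set `KERAS_BACKEND` is left untouched.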
To get all the dependencies and all the backends to run the demos, do:
```shell
pip install -r requirements.txt
```

## Getting the pretrained Segment Anything Model
```python
# Use the TensorFlow backend; choose any you want ("jax", "torch", ...).
# KERAS_BACKEND must be set before Keras or KerasCV is imported.
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

from keras_cv.models import SegmentAnythingModel
from sam_keras import SAMPredictor

# Get the huge model trained on the SA-1B dataset.
# Other available options are:
# - "sam_base_sa1b"
# - "sam_large_sa1b"
model = SegmentAnythingModel.from_preset("sam_huge_sa1b")

# Create the predictor.
predictor = SAMPredictor(model)

# Now you can use the predictor just like the one in the original repo.
# The only difference is that a list of input dicts isn't supported; instead,
# pass each input dict separately to the `predict` method.
```

## Notes
Right now, JAX and TensorFlow have a large compile-time overhead: the prompt
encoder recompiles each time a different combination of prompts (points only,
points + boxes, boxes only, etc.) is passed. To avoid this, compile the model
with `run_eagerly=True` and `jit_compile=False`.
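To see why switching prompt combinations is costly, here is a toy model of a JIT trace cache, written in plain Python so it runs without any backend. It is a hypothetical sketch, not how JAX or TensorFlow are implemented: real backends key their caches on the full input structure, shapes, and dtypes, whereas `ToyTraceCache` keys only on which prompt types are present. The class name and the `points`/`boxes` keys are illustrative assumptions.

```python
from typing import Any, Callable, Dict, FrozenSet


class ToyTraceCache:
    """Toy model of a JIT cache keyed on which prompt types are present.

    Hypothetical illustration only: real JAX/TensorFlow caches key on the
    full input structure, shapes and dtypes, not just the prompt names.
    """

    def __init__(self, fn: Callable[[Dict[str, Any]], Any]) -> None:
        self.fn = fn
        # One "compiled" entry per distinct prompt combination.
        self.compiled: Dict[FrozenSet[str], Callable[[Dict[str, Any]], Any]] = {}
        self.num_compiles = 0

    def __call__(self, prompts: Dict[str, Any]) -> Any:
        key = frozenset(prompts)  # e.g. {"points"} vs. {"points", "boxes"}
        if key not in self.compiled:
            # Stands in for the expensive trace + compile step.
            self.num_compiles += 1
            self.compiled[key] = self.fn
        return self.compiled[key](prompts)


if __name__ == "__main__":
    # The "encoder" here just reports which prompt types it received.
    encode = ToyTraceCache(lambda prompts: sorted(prompts))
    encode({"points": [[10, 20]]})                           # new combination
    encode({"points": [[30, 40]]})                           # same combination: cache hit
    encode({"points": [[10, 20]], "boxes": [[0, 0, 5, 5]]})  # new combination
    encode({"boxes": [[0, 0, 5, 5]]})                        # new combination
    print(encode.num_compiles)  # prints 3
```

Running eagerly with `run_eagerly=True` and `jit_compile=False` sidesteps this cache entirely: nothing is traced or compiled, so there is no per-combination cost, although each individual call is typically slower than a cached compiled one.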