https://github.com/licht-t/tf-centernet
CenterNet implementation with Tensorflow 2
- Host: GitHub
- URL: https://github.com/licht-t/tf-centernet
- Owner: Licht-T
- License: mit
- Created: 2020-09-22T09:02:16.000Z (almost 6 years ago)
- Default Branch: master
- Last Pushed: 2020-09-24T13:52:42.000Z (almost 6 years ago)
- Last Synced: 2024-12-16T11:56:20.735Z (over 1 year ago)
- Topics: centernet, object-detection, object-recognition, python, python3, tensorflow, tensorflow2
- Language: Python
- Homepage:
- Size: 2.06 MB
- Stars: 7
- Watchers: 3
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# tf-centernet

[CenterNet](https://arxiv.org/abs/1904.07850) implementation with TensorFlow 2.
## Install
```bash
pip install tf-centernet
```
## Example
### Object detection
```python
import numpy as np
import PIL.Image
import centernet
# Default: num_classes=80
obj = centernet.ObjectDetection(num_classes=80)
# Default: weights_path=None
# num_classes=80 and weights_path=None: the pre-trained COCO model is loaded.
# Otherwise: the user-defined weight file at weights_path is loaded.
obj.load_weights(weights_path=None)
# PIL decodes to RGB; convert('RGB') guarantees 3 channels (e.g. for grayscale
# or RGBA files) and [..., ::-1] reverses the channel order to BGR.
img = np.array(PIL.Image.open('./data/sf.jpg').convert('RGB'))[..., ::-1]
# If `debug=True`, an image with the predicted bounding boxes is created
boxes, classes, scores = obj.predict(img, debug=True)
```
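For readers new to the method: the paper linked above detects objects as local maxima ("peaks") of a per-class center heatmap, then reads the box size and a sub-pixel offset at each peak. The sketch below illustrates that decoding step in NumPy. It is not code from this package; the function names, the array layouts (an `(H, W, C)` heatmap and `(H, W, 2)` size and offset maps) and the `stride=4` default are assumptions made for illustration.

```python
import numpy as np


def max_pool_3x3(heatmap: np.ndarray) -> np.ndarray:
    """3x3 max pooling with stride 1 and 'same' padding over an (H, W, C) map."""
    h, w, _ = heatmap.shape
    # Pad with -inf so border pixels only compare against real neighbours.
    padded = np.pad(heatmap, ((1, 1), (1, 1), (0, 0)), constant_values=-np.inf)
    pooled = np.full_like(heatmap, -np.inf)
    for dy in range(3):
        for dx in range(3):
            pooled = np.maximum(pooled, padded[dy:dy + h, dx:dx + w, :])
    return pooled


def decode_centers(heatmap: np.ndarray, wh: np.ndarray, offset: np.ndarray,
                   k: int = 100, stride: int = 4):
    """Hypothetical CenterNet-style decoder.

    heatmap: (H, W, C) class-wise center scores in [0, 1]
    wh:      (H, W, 2) box (width, height) in output-map units
    offset:  (H, W, 2) sub-pixel (dx, dy) correction of each center
    Returns (x1, y1, x2, y2) boxes scaled by `stride`, class ids and scores.
    """
    # Keep a location only if it is the maximum of its 3x3 neighbourhood;
    # this peak test replaces box-level NMS.
    peaks = np.where(heatmap == max_pool_3x3(heatmap), heatmap, 0.0)
    flat = peaks.ravel()
    k = min(k, flat.size)
    top = np.argsort(flat)[::-1][:k]  # indices of the k highest peaks
    scores = flat[top]
    ys, xs, classes = np.unravel_index(top, heatmap.shape)
    cx = xs + offset[ys, xs, 0]
    cy = ys + offset[ys, xs, 1]
    w = wh[ys, xs, 0]
    h = wh[ys, xs, 1]
    boxes = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    # Scale from the low-resolution output map back to input pixels
    # (ignores any resize or crop a real pre-processing pipeline applies).
    return boxes * stride, classes, scores
```

The paper keeps the top 100 peaks, hence the `k=100` default; a real pipeline would then drop peaks below a score threshold.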

### Pose estimation
```python
import numpy as np
import PIL.Image
import centernet
# Default: num_joints=17
pe = centernet.PoseEstimation(num_joints=17)
# Default: weights_path=None
# num_joints=17 and weights_path=None: the pre-trained COCO model is loaded.
# Otherwise: the user-defined weight file at weights_path is loaded.
pe.load_weights(weights_path=None)
# Score threshold; adjust it for better predictions
pe.score_threshold = 0.1
# PIL decodes to RGB; convert('RGB') guarantees 3 channels and
# [..., ::-1] reverses the channel order to BGR.
img = np.array(PIL.Image.open('./data/chi.jpg').convert('RGB'))[..., ::-1]
# If `debug=True`, an image with the predicted keypoints is created
boxes, keypoints, scores = pe.predict(img, debug=True)
```
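The 17 joints of the pre-trained COCO model follow the COCO person-keypoint order. The README does not document the layout of the returned `keypoints`, so the helper below is a hypothetical sketch: it assumes `keypoints` has shape `(N, 17, 2)` holding `(x, y)` per joint in COCO order and `scores` has shape `(N,)`, and it mimics a score cut-off similar in spirit to `score_threshold`. Check the actual shapes before relying on it.

```python
import numpy as np

# Joint order of the COCO person-keypoint annotations (17 joints).
COCO_JOINTS = (
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
)


def people_as_dicts(keypoints, scores, threshold=0.1):
    """Hypothetical helper: keep people scoring >= threshold and name each joint.

    Assumes keypoints has shape (N, 17, 2) with (x, y) per joint in COCO order.
    """
    keypoints = np.asarray(keypoints, dtype=float)
    scores = np.asarray(scores, dtype=float)
    keep = scores >= threshold
    people = []
    for person, score in zip(keypoints[keep], scores[keep]):
        joints = {name: (float(x), float(y))
                  for name, (x, y) in zip(COCO_JOINTS, person)}
        people.append({"score": float(score), "joints": joints})
    return people
```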

## TODO
* [x] Object detection
* [x] Pre-trained model for object detection with Hourglass-104
* [x] Pose estimation
* [x] Pre-trained model for pose estimation with Hourglass-104
* [ ] DLA-34 backbone and pre-trained models
* [ ] Training function and loss definition
* [ ] Training data augmentation
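Training and the loss are still open items. For reference, the CenterNet paper trains the center heatmap with a penalty-reduced pixel-wise focal loss (α = 2, β = 4, normalized by the number of object centers). The NumPy sketch below follows that formula; it is not code from this repository, and the function name and the `eps` clipping are illustrative choices.

```python
import numpy as np


def centernet_focal_loss(pred, gt, alpha=2.0, beta=4.0, eps=1e-7):
    """Penalty-reduced pixel-wise focal loss on center heatmaps (sketch).

    pred: predicted heatmap in (0, 1); gt: Gaussian-splatted target in [0, 1]
    that equals 1 exactly at object centers. Both arrays share one shape.
    """
    pred = np.clip(pred, eps, 1.0 - eps)  # avoid log(0)
    pos = gt == 1.0
    neg = ~pos
    # Positives: standard focal term, small when pred is close to 1.
    pos_loss = ((1.0 - pred) ** alpha * np.log(pred))[pos].sum()
    # Negatives: (1 - gt)^beta down-weights pixels close to a center.
    neg_loss = ((1.0 - gt) ** beta * pred ** alpha * np.log(1.0 - pred))[neg].sum()
    num_pos = max(int(pos.sum()), 1)  # normalize by the number of centers
    return -(pos_loss + neg_loss) / num_pos
```

With a uniform prediction of 0.5 on a 4×4 map containing one center, every pixel contributes 0.25·ln 2, so the loss is 4·ln 2.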