https://github.com/ashok-arjun/zero-shot-sketch-based-image-retrieval
Zero-shot sketch-based image retrieval using a domain adversarial neural network
https://github.com/ashok-arjun/zero-shot-sketch-based-image-retrieval
computer-vision deep-learning domain-adversarial gradient-reversal-layer image-retrieval machine-learning pytorch sketch-based-image-retrieval zero-shot-learning
Last synced: 2 months ago
JSON representation
Zero-shot sketch-based image retrieval using a domain adversarial neural network
- Host: GitHub
- URL: https://github.com/ashok-arjun/zero-shot-sketch-based-image-retrieval
- Owner: ashok-arjun
- Created: 2020-07-29T06:29:15.000Z (about 5 years ago)
- Default Branch: master
- Last Pushed: 2022-11-22T09:48:36.000Z (almost 3 years ago)
- Last Synced: 2023-03-05T07:42:50.935Z (over 2 years ago)
- Topics: computer-vision, deep-learning, domain-adversarial, gradient-reversal-layer, image-retrieval, machine-learning, pytorch, sketch-based-image-retrieval, zero-shot-learning
- Language: Python
- Homepage:
- Size: 17.8 MB
- Stars: 18
- Watchers: 2
- Forks: 2
- Open Issues: 3
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
![]()
# Zero Shot Sketch-Based Image Retrieval
The problem of retrieving images from a large-database using ambiguous sketches has been addressed. This problem has been addressed in the **zero-shot scenario**, where the test sketches/images are from **unseen classes** and the deep feature extractor's output embedding distances of the sketch and the image have been used to retrieve the top k closest images from the image database of the unseen classes.
The standard **triplet loss** has been used, along with a **domain loss**, which is trained to differentiate between sketches and images.
The embedding is passed in to **gradient reversal layer** 1, and then into a domain-classifier network, and then into the domain loss. The **gradient reversal layer** acts as an **identity layer in the forward pass**, and **multiplies the gradient by -1 in the backward pass** which results in the main network learning a **domain-agnostic representation** i.e. to fool the domain-classifier network.
# Architecture Overview

† Diagram created by Arjun Ashok using [app.diagrams.net](http://app.diagrams.net)
# Results on unseen(zero-shot) classes
The below table presents a few qualitative results of our model on unseen test classes. The table presents the top 5 results from left to right.
| Query Sketch 1 |
|:---------------:|
||| Retrieved Image 1 | Retrieved Image 2 | Retrieved Image 3 | Retrieved Image 4 | Retrieved Image 5 |
|:-----------------:|:-----------------:|:-----------------:|:-----------------:|:-----------------:|
||||||| Query Sketch 2 |
|:---------------:|
||| Retrieved Image 1 | Retrieved Image 2 | Retrieved Image 3 | Retrieved Image 4 | Retrieved Image 5 |
|:-----------------:|:-----------------:|:-----------------:|:-----------------:|:-----------------:|
||||||# Instructions
Installation
Please execute the following command to install the required libraries:
```
pip install -r requirements.txt
```Data
Execute ```bash download_data.sh```
Training
The file ```train.py``` can be invoked with the following arguments:
```
usage: train.py [-h] --data_dir DATA_DIR --batch_size BATCH_SIZE
--checkpoint_dir CHECKPOINT_DIR --epochs EPOCHS
[--domain_loss_ratio DOMAIN_LOSS_RATIO]
[--triplet_loss_ratio TRIPLET_LOSS_RATIO]
[--grl_threshold_epoch GRL_THRESHOLD_EPOCH]
[--print_every PRINT_EVERY]Training of SBIR
optional arguments:
-h, --help show this help message and exit
--data_dir DATA_DIR Data directory path. Directory should contain two
folders - sketches and photos, along with 2 .txt files
for the labels
--batch_size BATCH_SIZE
Batch size to process the train sketches/photos
--checkpoint_dir CHECKPOINT_DIR
Directory to save checkpoints
--epochs EPOCHS Number of epochs
--domain_loss_ratio DOMAIN_LOSS_RATIO
Domain loss weight
--triplet_loss_ratio TRIPLET_LOSS_RATIO
Triplet loss weight
--grl_threshold_epoch GRL_THRESHOLD_EPOCH
Threshold epoch for GRL lambda
--print_every PRINT_EVERY
Logging interval in iterations
```It is advised to use a GPU for training. The code automatically detects and uses a GPU, if available.
Inference
The file ```evaluate.py``` can be invoked with the following args:
```
usage: evaluate.py [-h] [--model MODEL] --data DATA [--num_images NUM_IMAGES]
[--num_sketches NUM_SKETCHES] [--batch_size BATCH_SIZE]
[--output_dir OUTPUT_DIR]Evaluation of SBIR
arguments:
-h, --help show this help message and exit
--model MODEL Model checkpoint path
--data DATA Data directory path. Directory should contain two
folders - sketches and photos, along with 2 .txt files
for the labels
--num_images NUM_IMAGES
Number of random images to output for every
sketch
--num_sketches NUM_SKETCHES
Number of random sketches to output
--batch_size BATCH_SIZE
Batch size to process the test sketches/photos
--output_dir OUTPUT_DIR
Directory to save output sketch and images
```It is advised to use a GPU for evaluation. The code automatically detects and uses a GPU, if available.
# References
1. Ganin, Yaroslav et al. "Domain-Adversarial Training Of Neural Networks". Journal of Machine Learning Research, 2016, pp. 1-35, url:http://jmlr.org/papers/v17/15-239.html