# Social GAN

This is the code for the paper

**Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks**


Agrim Gupta,
Justin Johnson,
Fei-Fei Li,
Silvio Savarese,
Alexandre Alahi


Presented at [CVPR 2018](http://cvpr2018.thecvf.com/)

Human motion is interpersonal, multimodal and follows social conventions. In this paper, we tackle the trajectory prediction problem by combining tools from sequence prediction and generative adversarial networks: a recurrent sequence-to-sequence model observes motion histories and predicts future behavior, using a novel pooling mechanism to aggregate information across people.

Below we show examples of socially acceptable predictions made by our model in complex scenarios. Each person is denoted by a different color; observed trajectories are shown as dots and predicted trajectories as stars.




If you find this code useful in your research then please cite
```
@inproceedings{gupta2018social,
  title={Social GAN: Socially Acceptable Trajectories with Generative Adversarial Networks},
  author={Gupta, Agrim and Johnson, Justin and Fei-Fei, Li and Savarese, Silvio and Alahi, Alexandre},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  number={CONF},
  year={2018}
}
```

## Model
Our model consists of three key components: a Generator (G), a Pooling Module (PM) and a Discriminator (D). G is based on an encoder-decoder framework, where we link the hidden states of the encoder and decoder via the PM. G takes as input the trajectories of all people in a scene and outputs the corresponding predicted trajectories. D takes as input the entire sequence, comprising both the observed trajectory and the future prediction, and classifies it as “real/fake”.
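
As a rough illustration of how these pieces fit together, here is a minimal PyTorch sketch of the G pipeline: encode each person's history, pool across people, then decode future positions. This is not the repository's implementation; all names are invented, the mean-based pooling is a crude stand-in for the paper's Pooling Module, and the noise input that makes G generative is omitted.

```python
import torch
import torch.nn as nn

class ToyTrajectoryGenerator(nn.Module):
    """Sketch of G: encode motion histories, pool across people, decode futures."""
    def __init__(self, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embed = nn.Linear(2, embed_dim)            # (x, y) -> embedding
        self.encoder = nn.LSTM(embed_dim, hidden_dim)   # observes history per person
        self.pool_mlp = nn.Linear(2 * hidden_dim, hidden_dim)
        self.decoder = nn.LSTM(embed_dim, hidden_dim)   # rolls out future positions
        self.out = nn.Linear(hidden_dim, 2)             # hidden state -> (dx, dy)

    def pool(self, h):
        # Crude stand-in for the Pooling Module: give each person the
        # scene-wide mean hidden state alongside their own state.
        mean = h.mean(dim=1, keepdim=True).expand_as(h)
        return self.pool_mlp(torch.cat([h, mean], dim=-1))

    def forward(self, obs_traj, pred_len=8):
        # obs_traj: (obs_len, num_peds, 2) absolute positions
        _, (h, c) = self.encoder(self.embed(obs_traj))
        h = self.pool(h)                    # link encoder/decoder hidden states via pooling
        last_pos = obs_traj[-1]
        preds = []
        step = self.embed(last_pos).unsqueeze(0)
        for _ in range(pred_len):
            out, (h, c) = self.decoder(step, (h, c))
            last_pos = last_pos + self.out(out.squeeze(0))
            preds.append(last_pos)
            step = self.embed(last_pos).unsqueeze(0)
        return torch.stack(preds)           # (pred_len, num_peds, 2)
```

Feeding it a tensor of shape `(obs_len, num_peds, 2)` returns predicted positions of shape `(pred_len, num_peds, 2)`; D would then score the concatenation of the observed and predicted segments.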



## Setup
All code was developed and tested on Ubuntu 16.04 with Python 3.5 and PyTorch 0.4.

You can setup a virtual environment to run the code like this:

```bash
python3 -m venv env # Create a virtual environment
source env/bin/activate # Activate virtual environment
pip install -r requirements.txt # Install dependencies
echo $PWD > env/lib/python3.5/site-packages/sgan.pth # Add current directory to python path
# Work for a while ...
deactivate # Exit virtual environment
```
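
As a quick sanity check (assuming the `.pth` trick above put the repository root on the Python path), you can verify that both PyTorch and the `sgan` package import cleanly:

```bash
python -c "import torch, sgan; print(torch.__version__)"
```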

## Pretrained Models
You can download pretrained models by running the script `bash scripts/download_models.sh`. This will download the following models:

- `sgan-models/<dataset_name>_<pred_len>.pt`: Contains 10 pretrained models for all five datasets. These models correspond to SGAN-20V-20 in Table 1.
- `sgan-p-models/<dataset_name>_<pred_len>.pt`: Contains 10 pretrained models for all five datasets. These models correspond to SGAN-20VP-20 in Table 1.

Please refer to [Model Zoo](MODEL_ZOO.md) for results.
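
If you want to inspect a downloaded checkpoint directly, it can be loaded with plain `torch.load`; the path below is illustrative, and the exact dictionary keys depend on what the training code saved:

```python
import torch

# Illustrative path: any file downloaded into models/sgan-models/ works.
checkpoint = torch.load('models/sgan-models/checkpoint.pt', map_location='cpu')
if isinstance(checkpoint, dict):
    print(list(checkpoint.keys()))  # e.g. model weights, training args, metrics
```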

## Running Models
You can use the script `scripts/evaluate_model.py` to easily run any of the pretrained models on any of the datasets. For example, you can replicate the Table 1 results for all datasets for SGAN-20V-20 like this:

```bash
python scripts/evaluate_model.py \
  --model_path models/sgan-models
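
The reported numbers are ADE (average displacement error: the mean L2 distance between predicted and ground-truth positions over all predicted time steps) and FDE (final displacement error: the L2 distance at the last predicted step). Independent of the repository's own evaluation code, the two metrics boil down to this sketch:

```python
import torch

def ade_fde(pred, gt):
    """pred, gt: (pred_len, num_peds, 2) predicted and ground-truth positions."""
    dist = torch.norm(pred - gt, dim=-1)  # (pred_len, num_peds) L2 error per step
    return dist.mean().item(), dist[-1].mean().item()  # ADE, FDE
```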

## Training new models
Instructions for training new models can be [found here](TRAINING.md).