https://github.com/shawonashraf/postagger-lstm-jax
LSTM POS Tagger implementation in Jax and Flax
flax jax optax wandb
- Host: GitHub
- URL: https://github.com/shawonashraf/postagger-lstm-jax
- Owner: ShawonAshraf
- Created: 2024-01-15T19:05:33.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-01-16T01:04:06.000Z (over 1 year ago)
- Last Synced: 2025-02-08T05:27:26.565Z (8 months ago)
- Topics: flax, jax, optax, wandb
- Language: Python
- Homepage:
- Size: 22.5 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# postagger-lstm-jax
A single-layer LSTM part-of-speech tagger implemented in JAX (+Flax) and trained on the `batterydata/pos_tagging` dataset from Hugging Face Datasets.

## Usage
Make sure that you have a wandb account and have logged in using your API key.
```bash
wandb login
```

Then run `main.py` with the following arguments:
```bash
python main.py --lr 0.01 --epochs 5 --batch-size 128 --seed 2023 --dropout 0.2 \
--embedding-dim 300 --hidden-dim 300 --max_seq_len 300 \
--pad_token_idx 1 --log_every_n_step 100
```

_The Trainer module trains, evaluates, and logs to wandb in the same run._
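The Trainer's code is not reproduced in this README. To make concrete what one training step for this kind of tagger involves, here is a minimal sketch in plain JAX, assuming padded batches and a loss that ignores padding. It is not the repository's implementation: the project lists Flax and optax, whereas this sketch writes the LSTM recurrence by hand with `jax.lax.scan` and applies a plain SGD update. All names (`pad_batch`, `init_params`, `lstm_tagger`, `masked_loss_and_accuracy`, `train_step`) are hypothetical, dropout is omitted, the tag padding value `0` is an assumption, and the sizes are toy values rather than the 300-dimensional settings used in the command above.

```python
import jax
import jax.numpy as jnp

PAD_TOKEN_IDX = 1  # mirrors --pad_token_idx 1 from the command above
LR = 0.01          # mirrors --lr 0.01; plain SGD here, not optax (assumption)


def pad_batch(sequences, max_seq_len, pad_idx):
    """Hypothetical helper: truncate/pad id lists; mask is True on real tokens."""
    lengths = [min(len(s), max_seq_len) for s in sequences]
    ids = [list(s[:max_seq_len]) + [pad_idx] * (max_seq_len - n) for s, n in zip(sequences, lengths)]
    mask = jnp.arange(max_seq_len)[None, :] < jnp.array(lengths)[:, None]
    return jnp.array(ids, dtype=jnp.int32), mask


def init_params(key, vocab_size, num_tags, embedding_dim, hidden_dim):
    k_emb, k_x, k_h, k_out = jax.random.split(key, 4)
    s = 0.1
    return {
        "embed": s * jax.random.normal(k_emb, (vocab_size, embedding_dim)),
        # Input, forget, cell and output gate weights stacked along the last axis.
        "w_x": s * jax.random.normal(k_x, (embedding_dim, 4 * hidden_dim)),
        "w_h": s * jax.random.normal(k_h, (hidden_dim, 4 * hidden_dim)),
        "b": jnp.zeros((4 * hidden_dim,)),
        "w_out": s * jax.random.normal(k_out, (hidden_dim, num_tags)),
        "b_out": jnp.zeros((num_tags,)),
    }


def lstm_tagger(params, token_ids):
    """(batch, seq_len) token ids -> (batch, seq_len, num_tags) tag logits."""
    x = params["embed"][token_ids]  # (batch, seq_len, embedding_dim)
    batch, hidden_dim = token_ids.shape[0], params["w_h"].shape[0]

    def step(carry, x_t):
        h, c = carry
        gates = x_t @ params["w_x"] + h @ params["w_h"] + params["b"]
        i, f, g, o = jnp.split(gates, 4, axis=-1)
        c = jax.nn.sigmoid(f) * c + jax.nn.sigmoid(i) * jnp.tanh(g)
        h = jax.nn.sigmoid(o) * jnp.tanh(c)
        return (h, c), h

    carry = (jnp.zeros((batch, hidden_dim)), jnp.zeros((batch, hidden_dim)))
    # lax.scan loops over the leading axis, so scan time-major and swap back.
    _, hs = jax.lax.scan(step, carry, jnp.swapaxes(x, 0, 1))
    return jnp.swapaxes(hs, 0, 1) @ params["w_out"] + params["b_out"]


def masked_loss_and_accuracy(logits, tags, mask):
    """Token-level cross-entropy and accuracy over non-padded positions only."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    gold = jnp.take_along_axis(log_probs, tags[..., None], axis=-1).squeeze(-1)
    m = mask.astype(logits.dtype)
    n = jnp.maximum(m.sum(), 1.0)
    acc = ((jnp.argmax(logits, axis=-1) == tags) * m).sum() / n
    return -(gold * m).sum() / n, acc


@jax.jit
def train_step(params, token_ids, tags, mask):
    def loss_fn(p):
        return masked_loss_and_accuracy(lstm_tagger(p, token_ids), tags, mask)

    (loss, acc), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss, acc


# Toy usage: two sentences padded to length 6.
token_ids, mask = pad_batch([[5, 8, 13], [21, 34]], max_seq_len=6, pad_idx=PAD_TOKEN_IDX)
tags, _ = pad_batch([[0, 2, 3], [1, 0]], max_seq_len=6, pad_idx=0)
params = init_params(jax.random.PRNGKey(2023), vocab_size=50, num_tags=4, embedding_dim=8, hidden_dim=16)
for _ in range(20):
    params, loss, acc = train_step(params, token_ids, tags, mask)
```

Padding every batch to the same `max_seq_len` keeps array shapes fixed, so `jax.jit` compiles `train_step` once instead of retracing for every new sentence length; the mask then keeps the padded positions out of both the loss and the accuracy.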
## Logs and Results
Check the logged metrics in the [wandb run](https://wandb.ai/shawonashraf/postagger-lstm-jax/runs/bs5n1ukb?workspace=user-shawonashraf).

## Environment Setup
__Version Requirements:__
- Python 3.11
- CUDA 12.2

```bash
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html --extra-index-url https://download.pytorch.org/whl/cpu
```
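After installing the requirements, a quick way to confirm that JAX picked up the CUDA-enabled build, rather than running on the CPU, is to ask which backend and devices it sees. This is a small sanity check that is not part of the repository; it only uses the public `jax.default_backend()` and `jax.devices()` functions plus a tiny jitted computation.

```python
import jax
import jax.numpy as jnp

# "gpu" means JAX found a usable CUDA device; "cpu" means it did not.
backend = jax.default_backend()
devices = jax.devices()
print(backend, devices)

# A tiny jitted computation to confirm the install actually runs end to end.
x = jnp.arange(4.0)
total = jax.jit(lambda v: (v * v).sum())(x)
print(total)  # 14.0
```

If this reports `cpu` on a CUDA 12.2 machine, the CUDA-enabled `jaxlib` wheel was most likely not installed.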