https://github.com/KhoomeiK/LlamaGym
Fine-tune LLM agents with online reinforcement learning
- Host: GitHub
- URL: https://github.com/KhoomeiK/LlamaGym
- Owner: KhoomeiK
- License: mit
- Created: 2024-03-01T03:03:38.000Z (10 months ago)
- Default Branch: main
- Last Pushed: 2024-03-19T17:34:28.000Z (9 months ago)
- Last Synced: 2024-05-04T00:01:22.779Z (7 months ago)
- Language: Python
- Homepage:
- Size: 1.19 MB
- Stars: 908
- Watchers: 8
- Forks: 36
- Open Issues: 6
Metadata Files:
- Readme: README.md
- License: LICENSE
# LlamaGym
"Agents" originated in reinforcement learning, where they learn by interacting with an environment and receiving a reward signal. However, LLM-based agents today do not learn online (i.e. continuously in real time) via reinforcement.

OpenAI created [Gym](https://github.com/Farama-Foundation/Gymnasium) to standardize and simplify RL environments, but if you try dropping an LLM-based agent into a Gym environment for training, you'll find it still takes quite a bit of code to handle LLM conversation context, episode batches, reward assignment, PPO setup, and more.
LlamaGym seeks to simplify fine-tuning LLM agents with RL. Right now, it's a single `Agent` abstract class that handles all the issues mentioned above, letting you quickly iterate and experiment with agent prompting & hyperparameters across any Gym environment.
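To make that bookkeeping concrete, here is a toy sketch of the pattern. It is not LlamaGym's actual implementation: `ToyAgent`, `DemoAgent`, the `llm` callable, and `batch_size` are hypothetical names for illustration, the "LLM" is any function that maps a message list to a string, and the PPO update is replaced by a placeholder that only reports statistics. The method names mirror the ones used in the Usage section.

```python
from abc import ABC, abstractmethod
from typing import Callable, Optional


class ToyAgent(ABC):
    """Hypothetical stand-in showing the bookkeeping an RL-trained LLM agent needs.

    Assumption: `llm` is any callable taking a list of chat messages and
    returning a string. No real model or PPO trainer is involved.
    """

    def __init__(self, llm: Callable[[list], str], batch_size: int = 2):
        self.llm = llm
        self.batch_size = batch_size
        self.messages = []  # conversation context for the current episode
        self.rewards = []   # one reward per action in the current episode
        self.batch = []     # finished episodes waiting for a training step

    @abstractmethod
    def get_system_prompt(self) -> str: ...

    @abstractmethod
    def format_observation(self, observation) -> str: ...

    @abstractmethod
    def extract_action(self, response: str): ...

    def act(self, observation):
        if not self.messages:  # start of an episode: seed the context
            self.messages.append(
                {"role": "system", "content": self.get_system_prompt()}
            )
        self.messages.append(
            {"role": "user", "content": self.format_observation(observation)}
        )
        response = self.llm(self.messages)
        self.messages.append({"role": "assistant", "content": response})
        return self.extract_action(response)

    def assign_reward(self, reward: float) -> None:
        self.rewards.append(reward)

    def terminate_episode(self) -> Optional[dict]:
        # Store the finished episode, then reset the per-episode state.
        self.batch.append(
            {"messages": self.messages, "return": sum(self.rewards)}
        )
        self.messages, self.rewards = [], []
        if len(self.batch) < self.batch_size:
            return None  # batch not full yet, so no training step
        # Placeholder for a policy-gradient (e.g. PPO) update on the batch.
        stats = {
            "episodes": len(self.batch),
            "mean_return": sum(e["return"] for e in self.batch) / len(self.batch),
        }
        self.batch = []
        return stats


class DemoAgent(ToyAgent):
    def get_system_prompt(self) -> str:
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        return f"Your current total is {observation[0]}"

    def extract_action(self, response: str):
        return 0 if "stay" in response else 1
```

With a scripted stand-in such as `DemoAgent(llm=lambda messages: "I will stay.")`, calling `act`, `assign_reward`, and `terminate_episode` in a loop exercises the same control flow as a real training run, without loading a model.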
## Usage
Fine-tuning an LLM-based agent to play in a Gym-style environment with RL has never been easier! Once you install LlamaGym...
```
pip install llamagym
```
First, implement 3 abstract methods on the Agent class:
```python
from llamagym import Agent


class BlackjackAgent(Agent):
    def get_system_prompt(self) -> str:
        return "You are an expert blackjack player."

    def format_observation(self, observation) -> str:
        return f"Your current total is {observation[0]}"

    def extract_action(self, response: str):
        return 0 if "stay" in response else 1
```
Then, define your base LLM (as you would for any fine-tuning job) and instantiate your agent:
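```python
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead

# `device` is assumed to be defined already (e.g. "cuda" or "cpu")
model = AutoModelForCausalLMWithValueHead.from_pretrained("Llama-2-7b").to(device)
tokenizer = AutoTokenizer.from_pretrained("Llama-2-7b")
agent = BlackjackAgent(model, tokenizer, device)
```
Finally, write your RL loop as usual and simply call your agent to act, reward, and terminate: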
```python
import gymnasium as gym
from tqdm import trange

env = gym.make("Blackjack-v1")

for episode in trange(5000):
    observation, info = env.reset()
    done = False

    while not done:
        action = agent.act(observation)  # act based on observation
        observation, reward, terminated, truncated, info = env.step(action)
        agent.assign_reward(reward)  # provide reward to agent
        done = terminated or truncated

    train_stats = agent.terminate_episode()  # trains if batch is full
```
Some reminders:
- the code snippets above are mildly simplified, but a fully working example is available in [`examples/blackjack.py`](https://github.com/KhoomeiK/LlamaGym/blob/main/examples/blackjack.py)
- getting online RL to converge is notoriously difficult so you'll have to mess with hyperparameters to see improvement
- your model may also benefit from a supervised fine-tuning stage on sampled trajectories before running RL (we may add this feature in the future)
- our implementation values simplicity, so it is not as compute-efficient as e.g. [Lamorel](https://github.com/flowersteam/lamorel), but it is easier to start playing around with
- LlamaGym is a weekend project and still a WIP, but we love contributions!
## Relevant Work
- [Grounding Large Language Models with Online Reinforcement Learning](https://github.com/flowersteam/Grounding_LLMs_with_online_RL)
- [Lamorel: Language Models for Reinforcement Learning](https://github.com/flowersteam/lamorel)
- [True Knowledge Comes from Practice: Aligning LLMs with Embodied Environments via Reinforcement Learning](https://github.com/WeihaoTan/TWOSOME)
## Citation
```bibtex
@misc{pandey2024llamagym,
title = {LlamaGym: Fine-tune LLM agents with Online Reinforcement Learning},
author = {Rohan Pandey},
year = {2024},
howpublished = {GitHub},
url = {https://github.com/KhoomeiK/LlamaGym}
}
```