https://github.com/networks-learning/counterfactual-llms

Code for "Counterfactual Token Generation in Large Language Models", Arxiv 2024.

# Counterfactual Token Generation in Large Language Models

This repository contains the code used in the paper ["Counterfactual Token Generation in Large Language Models"](https://arxiv.org/abs/2409.17027) by Ivi Chatzi, Nina Corvelo Benz, Eleni Straitouri, Stratis Tsirtsis, and Manuel Gomez-Rodriguez.

__Contents__:
- [Introduction](#introduction)
- [Dependencies](#dependencies)
- [Repository structure](#repository-structure)
- [Contact & attribution](#contact--attribution)

## Introduction



*"Sure, I am happy to generate a story for you: Captain Lyra stood at the helm of her trusty ship, the Maelstrom’s Fury, gazing out at the endless sea. \[...\] Lyra’s eyes welled up with tears as she realized the bitter truth - she had sacrificed everything for fleeting riches, and lost the love of her crew, her family, and herself."* Although this story, generated by a large language model, is captivating, one may wonder: how would the story have unfolded if the model had chosen "Captain Maeve" as the protagonist instead? We cannot know. State-of-the-art large language models are stateless: they maintain no internal memory or state. Given a prompt, they generate a sequence of tokens as output using an autoregressive process. As a consequence, they cannot reason about counterfactual alternatives to tokens they have generated in the past. In this work, our goal is to enhance them with this functionality. To this end, we develop a causal model of token generation that builds upon the Gumbel-Max structural causal model. Our model allows any large language model to perform counterfactual token generation at almost no cost compared with vanilla token generation, is embarrassingly simple to implement, and requires neither fine-tuning nor prompt engineering. We implement our model on Llama 3 8B-Instruct and Ministral-8B-Instruct, and conduct both qualitative and quantitative analyses of counterfactually generated text. We conclude with a demonstrative application of counterfactual token generation for bias detection, unveiling interesting insights about the model of the world constructed by large language models.

## Dependencies

All experiments were performed using Python 3.11.2. To create a virtual environment and install the project dependencies, run the following commands:

```bash
python3 -m venv env
source env/bin/activate
pip install -r requirements.txt
```

Our code builds upon the popular open-weight large language models Llama 3 8B-Instruct and Ministral-8B-Instruct. For instructions on getting access to the weights of the Llama 3 model, refer to the [Llama 3 GitHub repository](https://github.com/meta-llama/llama3); for the weights of the Ministral model, refer to the Mistral [online documentation](https://docs.mistral.ai/getting-started/models/weights/). In addition to the dependencies of our own code listed above, make sure to first install the dependencies of the Llama 3 and Ministral models.

## Repository structure

```
├── data
├── figures
├── notebooks
├── outputs
│   ├── llama3
│   │   ├── census*
│   │   └── bias
│   ├── mistral
│   │   ├── census*
│   │   └── bias
│   └── story*
├── scripts
│   ├── bias.py
│   ├── census_queries.py
│   ├── cf_query.py
│   ├── stability.py
│   └── story_query.py
└── src
    ├── llama3
    │   ├── llama
    │   │   ├── generation.py
    │   │   └── sampler.py
    │   └── pretrained
    ├── mistral-inference
    │   └── src/mistral-inference
    │       ├── generate.py
    │       └── 8B-Instruct
    ├── bias.py
    ├── cf_query.py
    ├── sampler.py
    ├── single_query.py
    ├── stability.py
    └── utils.py
```

- `data` contains configuration files for our experiments.
- `figures` contains all the figures presented in the paper.
- `notebooks` contains Python notebooks that generate all the figures included in the paper.
- `outputs/llama3` and `outputs/mistral` contain the intermediate output files generated by the experiments' scripts for the respective model. Specifically:
  - `bias` contains the counterfactual census data of Section 4.3.
  - `census*` directories contain the factual census data of Section 4.3.
- `outputs/story*` directories contain the results of Section 4.1 and Appendix A.
- `scripts` contains the scripts used to run all the experiments presented in the paper.
- `src` contains all the code necessary to reproduce the results in the paper. Specifically:
  - `llama3` contains the code of Llama 3 8B-Instruct. Therein:
    - `llama/generation.py` uses the LLM to perform factual/counterfactual token generation.
    - `llama/pretrained/` is a placeholder directory where the weights of the (pre-trained) LLM should be placed.
  - `mistral-inference` contains the code of Ministral 8B-Instruct. Therein:
    - `src/mistral-inference/generate.py` uses the LLM to perform factual/counterfactual token generation.
    - `src/mistral-inference/8B-Instruct/` is a placeholder directory where the weights of the (pre-trained) LLM should be placed.
  - `sampler.py` samples from a token distribution using a Gumbel-Max SCM or its top-p and top-k variants.
  - `bias.py` performs counterfactual and interventional token generation for the experiments of Section 4.3 using the LLM-generated census data.
  - `cf_query.py` performs counterfactual token generation for a single query.
  - `single_query.py` performs factual token generation for a single query. It creates and saves its results in a subdirectory of `outputs/llama3` or `outputs/mistral`, depending on the model. The results are then used by `cf_query.py`.
  - `stability.py` performs interventional and counterfactual token generation for the experiments of Section 4.2.
  - `utils.py` contains auxiliary functions for plotting.
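The `sampler.py` entry above mentions top-p and top-k variants of the Gumbel-Max SCM. The sketch below shows one plausible way to combine them, assumed here and not taken from the repository: truncate the logits to the top-k and/or nucleus (top-p) set, then take the argmax of the truncated logits plus a full-vocabulary noise vector. The functions `truncate_logits` and `gumbel_max_sample` are hypothetical, and the nucleus rule used (keep the smallest prefix of tokens whose cumulative probability reaches `top_p`) may differ from reference implementations at the boundary.

```python
import numpy as np


def truncate_logits(logits, top_k=None, top_p=None):
    """Set logits outside the top-k and/or top-p (nucleus) set to -inf."""
    z = np.asarray(logits, dtype=float).copy()
    if top_k is not None:
        # Keep only the top_k largest logits (assumes top_k >= 1).
        z[np.argsort(z)[:-top_k]] = -np.inf
    if top_p is not None:
        order = np.argsort(z)[::-1]  # token indices, most likely first
        probs = np.exp(z[order] - z[order[0]])
        probs /= probs.sum()
        # Keep the smallest prefix whose cumulative probability reaches top_p.
        cutoff = int(np.searchsorted(np.cumsum(probs), top_p)) + 1
        z[order[cutoff:]] = -np.inf
    return z


def gumbel_max_sample(logits, noise, top_k=None, top_p=None, temperature=1.0):
    """Gumbel-Max sampling restricted to the truncated support.

    Tokens outside the top-k/top-p set have logit -inf and can never win the
    argmax, whatever their noise; the noise vector itself is left untouched so
    it can be reused in a counterfactual run.
    """
    z = truncate_logits(np.asarray(logits, dtype=float) / temperature, top_k, top_p)
    return int(np.argmax(z + noise))


logits = [3.0, 2.0, 1.0, 0.0]
noise = np.array([0.0, 0.0, 5.0, 100.0])
print(gumbel_max_sample(logits, noise))             # 3 (no truncation)
print(gumbel_max_sample(logits, noise, top_k=2))    # 0 (only tokens 0 and 1 remain)
print(gumbel_max_sample(logits, noise, top_p=0.9))  # 2 (tokens 0, 1, 2 remain)
```

One reason to draw noise for the whole vocabulary rather than only for the retained tokens is that each token's noise then does not depend on which other tokens survive truncation, so the same noise can be reused unchanged when an intervention changes the prefix.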

## Contact & attribution

If you have questions about the code, identify potential bugs, or would like us to include additional functionality, feel free to open an issue or contact [Ivi Chatzi](mailto:[email protected]) or [Stratis Tsirtsis](mailto:[email protected]).

If you use parts of the code in this repository for your own research, please consider citing:

```bibtex
@article{chatzi2024counterfactual,
  title={Counterfactual Token Generation in Large Language Models},
  author={Ivi Chatzi and Nina Corvelo Benz and Eleni Straitouri and Stratis Tsirtsis and Manuel Gomez-Rodriguez},
  year={2024},
  journal={arXiv preprint arXiv:2409.17027}
}
```