# CoT-ICL Lab

A Synthetic Framework for Studying Chain-of-Thought Learning from In-Context Demonstrations

## Setup
- Create a virtual environment and install the package.
```bash
$ python3.12 -m venv .venv
$ source .venv/bin/activate
(.venv) $ pip install -e .
```

- Run unit tests as a sanity check.
```bash
(.venv) $ pytest
```

- (Development) Run `ruff` + `isort` fixes to sanitize the code changes.
```bash
(.venv) $ ./beautify.sh
```

## Getting Started
Our framework serves as a test bed for generating synthetic tokenized datasets to train and evaluate transformer models. We do so using the `DAG` and `TokenProcessor` classes, which can be configured directly via the `Args` dataclass. For example:
```py
from tokenized_cot_icl.core.args import Args
from tokenized_cot_icl.core.data import TokenizedDataset

args = Args(
vocab_size=1024,
n_inputs=4,
n_parents=2,
chain_length=3,
n_examples=1,
enable_cot=True,
prompt_strategy="cot",
activation="leaky_relu",
n_tasks=10,
)

dataset = TokenizedDataset(args=args)
print(dataset[0])
```

The printed item looks as follows:
```py
{
    'adj_list': tensor([[0, 2], [4, 3], [5, 3]]),
    'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1]),
    'input_ids': tensor([ 556, 197, 1002, 867, 240, 466, 217]),
    'labels': tensor([-100, -100, -100, -100, 240, 466, 217]),
    'cot_eval': {
        'attention_mask': tensor([1, 1, 1, 1]),
        'input_ids': tensor([ 556, 197, 1002, 867]),
        'last_example_cot': tensor([240, 466, 217])
    }
}
```

### Understanding the DAG structure
Let's break down this result to understand the DAG structure. Consider the $4$ input tokens ($x_1, x_2, x_3, x_4$) and the $3$ chain tokens ($y_1, y_2, y_3$) of the single example above.
The entry `'adj_list': tensor([[0, 2], [4, 3], [5, 3]])` (zero-indexed) indicates that the parent tokens of the chain tokens are as follows:
| Chain Token | Parent Tokens |
|-------------|------------------|
| $y_1$ | $\{x_1, x_3\}$ |
| $y_2$ | $\{y_1, x_4\}$ |
| $y_3$ | $\{y_2, x_4\}$ |
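To make the indexing convention concrete, here is a minimal helper (not part of the library; the scheme is inferred from the example above) that decodes an `adj_list` into the labels used in the table. Indices `0..n_inputs-1` refer to input tokens, larger indices to earlier chain tokens:

```py
import torch

def decode_adj_list(adj_list: torch.Tensor, n_inputs: int) -> dict:
    """Map each chain token y_i to human-readable labels of its parents."""
    parents = {}
    for i, row in enumerate(adj_list.tolist()):
        parents[f"y_{i + 1}"] = [
            # indices below n_inputs are input tokens x_j; the rest are chain tokens
            f"x_{j + 1}" if j < n_inputs else f"y_{j - n_inputs + 1}"
            for j in row
        ]
    return parents

print(decode_adj_list(torch.tensor([[0, 2], [4, 3], [5, 3]]), n_inputs=4))
# {'y_1': ['x_1', 'x_3'], 'y_2': ['y_1', 'x_4'], 'y_3': ['y_2', 'x_4']}
```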
> [!NOTE]
> The TokenCoverage metric introduced in the paper relies on the uniqueness of chain tokens across the entire dataset and depends heavily on `vocab_size` and `activation`, which thereby control the difficulty of the tasks.
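As a rough illustration only (see the paper for the exact definition of TokenCoverage), one way to approximate such a coverage statistic is to count the distinct chain tokens appearing across the dataset relative to `vocab_size`:

```py
def approx_token_coverage(dataset, vocab_size: int) -> float:
    """Fraction of the vocabulary that occurs as a chain token.

    Illustrative approximation only; chain tokens are taken to be the
    unmasked positions of each item's `labels` tensor (see the example above).
    """
    seen = set()
    for item in dataset:
        labels = item["labels"]
        seen.update(labels[labels != -100].tolist())
    return len(seen) / vocab_size

# e.g. with the `dataset` built in the Getting Started example:
# print(approx_token_coverage(dataset, vocab_size=1024))
```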
## Models

We leverage the HuggingFace [transformers](https://github.com/huggingface/transformers) library to create custom Llama models and expose a `MODEL_REGISTRY` for registering new model families.
```py
# src/tokenized_cot_icl/core/models.py

MODEL_REGISTRY = {"llama": create_llama_model}
```

> [!TIP]
> Users can register the creation function for models of their choice from the `transformers` library to explore new architectures and validate ideas.
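For example, here is a hedged sketch of registering a second model family. The `create_gpt2_model` helper and its config values are illustrative, not part of the repository; we assume creation functions receive the `Args` instance, mirroring `create_llama_model`:

```py
# hypothetical addition to src/tokenized_cot_icl/core/models.py
from transformers import GPT2Config, GPT2LMHeadModel

def create_gpt2_model(args):
    """Build a small GPT-2 style model sized from the experiment args."""
    config = GPT2Config(
        vocab_size=args.vocab_size,  # match the synthetic dataset vocabulary
        n_embd=256,
        n_layer=4,
        n_head=4,
    )
    return GPT2LMHeadModel(config)

MODEL_REGISTRY["gpt2"] = create_gpt2_model
```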
## Training

### Setting the `TASK_CARD`
To make bulk-launching experiments convenient, we rely on a `TASK_CARD` that collates all the args. For instance, to train a model with the args from the example above:
```py
# src/tokenized_cot_icl/core/task_card.py

def custom_task_card() -> Dict[int, Args]:
"""A custom task card."""
args = Args(...) # set as needed
return {0: args}# set the dictionary
TASK_CARD = custom_task_card()
```
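Since `TASK_CARD` is just an integer-keyed dictionary, sweeps are easy to express. A minimal sketch, reusing only the `Args` fields shown earlier (the sweep values are illustrative):

```py
# hypothetical sweep in src/tokenized_cot_icl/core/task_card.py
def sweep_task_card() -> Dict[int, Args]:
    """One entry per vocab size; the key is passed via --task_card_key."""
    return {
        i: Args(vocab_size=v, n_inputs=4, n_parents=2, chain_length=3)
        for i, v in enumerate([256, 512, 1024])
    }

TASK_CARD = sweep_task_card()
```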
### Launch the DDP Training

The `TASK_CARD` lets us index into the experiment config of our choice and launch torch distributed data parallel (DDP) training runs. For example:
```bash
(.venv) $ cd src
(.venv) $ export NUM_NODES=1 # change as needed
(.venv) $ export LOCAL_WORLD_SIZE=4 # change as needed
(.venv) $ torchrun --nnodes=$NUM_NODES --nproc-per-node=$LOCAL_WORLD_SIZE -m tokenized_cot_icl.core.train --task_card_key 0
```

### Metric Logging
- By default, we use `metric_logger="stdout"` in `Args` and log the metrics/params to `STDOUT`.
- We also support logging to an [MLFlow](https://mlflow.org/docs/latest/tracking.html) tracking server by setting the `MLFLOW_SERVICE_URL` environment variable and using `Args(metric_logger="mlflow")`, as sketched below.
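A minimal sketch of the MLFlow configuration (the tracking URL is a placeholder; in practice the variable can also be exported in the shell that launches training):

```py
import os
from tokenized_cot_icl.core.args import Args

# point the logger at a running MLFlow tracking server (placeholder URL)
os.environ["MLFLOW_SERVICE_URL"] = "http://localhost:5000"

args = Args(metric_logger="mlflow")  # metrics/params now go to MLFlow
```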
### Liger-Kernels

Users can also apply the [Liger-Kernel](https://github.com/linkedin/Liger-Kernel) optimizations to patch the Llama models and speed up training runs by setting `Args(use_liger_kernels=True)`.
```bash
(.venv) $ pip install liger-kernel # install suitable version
```

## Inference with vLLM/SGLang
In addition to using `transformers.GenerationConfig` for small-scale inference during training runs, we also support [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) based evaluation of the trained model (or its checkpoints) to analyze the predictions.
```bash
(.venv) $ pip install vllm # install suitable version
(.venv) $ pip install sglang # install suitable version
```

We provide an easy-to-extend example for calculating the answer token prediction accuracy:
```bash
# for vllm
(.venv) $ cd src && python tokenized_cot_icl/inference/vllm/evaluator.py \
    --model_base_dir /opt/cot-icl-lab/run_name \
    --checkpoint final  # either final or 1000, 2000 etc.

# for sglang
(.venv) $ cd src && python tokenized_cot_icl/inference/sglang/evaluator.py \
    --model_base_dir /opt/cot-icl-lab/run_name \
    --checkpoint final  # either final or 1000, 2000 etc.
```

## License
[MIT License](LICENSE)
## Citation
```bibtex
@misc{kothapalli2025coticllabsyntheticframework,
  title={CoT-ICL Lab: A Synthetic Framework for Studying Chain-of-Thought Learning from In-Context Demonstrations},
  author={Vignesh Kothapalli and Hamed Firooz and Maziar Sanjabi},
  year={2025},
  eprint={2502.15132},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2502.15132},
}
```