{"id":13704422,"url":"https://github.com/csinva/imodelsX","last_synced_at":"2025-05-05T09:33:53.638Z","repository":{"id":60809029,"uuid":"545752139","full_name":"csinva/imodelsX","owner":"csinva","description":"Interpret text data using LLMs (scikit-learn compatible).","archived":false,"fork":false,"pushed_at":"2025-03-16T18:12:52.000Z","size":36704,"stargazers_count":163,"open_issues_count":5,"forks_count":26,"subscribers_count":5,"default_branch":"master","last_synced_at":"2025-04-11T15:28:20.246Z","etag":null,"topics":["ai","deep-learning","explainability","huggingface","interpretability","language-model","machine-learning","ml","natural-language-processing","natural-language-understanding","neural-network","pytorch","scikit-learn","text","text-classification","transformer-models","xai"],"latest_commit_sha":null,"homepage":"https://csinva.io/imodelsX/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/csinva.png","metadata":{"files":{"readme":"readme.md","changelog":null,"contributing":null,"funding":null,"license":"license.md","code_of_conduct":null,"threat_model":null,"audit":null,"citation":"citation.cff","codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2022-10-04T23:19:21.000Z","updated_at":"2025-03-16T18:12:55.000Z","dependencies_parsed_at":"2022-10-05T04:14:11.633Z","dependency_job_id":"f7b4d3dd-e9db-45de-9db1-4dd41e792a15","html_url":"https://github.com/csinva/imodelsX","commit_stats":{"total_commits":262,"total_committers":4,"mean_commits":65.5,"dds":0.06106870229007633,"last_synced_commit":"13d909a97bbd4483d6a27709fc88ce9e25d4cbc6"},"previous_names":[],"tags_count":15,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/rep
ositories/csinva%2FimodelsX","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/csinva%2FimodelsX/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/csinva%2FimodelsX/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/csinva%2FimodelsX/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/csinva","download_url":"https://codeload.github.com/csinva/imodelsX/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":252471724,"owners_count":21753239,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["ai","deep-learning","explainability","huggingface","interpretability","language-model","machine-learning","ml","natural-language-processing","natural-language-understanding","neural-network","pytorch","scikit-learn","text","text-classification","transformer-models","xai"],"created_at":"2024-08-02T21:01:09.314Z","updated_at":"2025-05-05T09:33:53.627Z","avatar_url":"https://github.com/csinva.png","language":"Python","readme":"\u003cp align=\"center\"\u003e  \u003cimg src=\"https://microsoft.github.io/aug-models/embgam_gif.gif\" width=\"18%\"\u003e \n\u003cimg align=\"center\" width=40% src=\"https://csinva.io/imodelsX/imodelsx_logo.svg?sanitize=True\u0026kill_cache=1\"\u003e \u003c/img\u003e\t\u003cimg src=\"https://microsoft.github.io/aug-models/embgam_gif.gif\" width=\"18%\"\u003e\u003c/p\u003e\n\n\u003cp align=\"center\"\u003eScikit-learn friendly library to explain, predict, and steer text models/data.\u003cbr/\u003eAlso 
a bunch of utilities for getting started with text data.\n\u003c/p\u003e\n\u003cp align=\"center\"\u003e\n  \u003ca href=\"https://github.com/csinva/imodelsX/tree/master/demo_notebooks\"\u003e📖 demo notebooks\u003c/a\u003e\n\u003c/p\u003e\n\u003cp align=\"center\"\u003e\n  \u003cimg src=\"https://img.shields.io/badge/license-mit-blue.svg\"\u003e\n  \u003cimg src=\"https://img.shields.io/badge/python-3.9+-blue\"\u003e\n  \u003cimg src=\"https://img.shields.io/pypi/v/imodelsx?color=green\"\u003e  \n\u003c/p\u003e  \n\n\n**Explainable modeling/steering**\n\n| Model                       | Reference                                                    | Output  | Description                                                  |\n| :-------------------------- | ------------------------------------------------------------ | ------- | ------------------------------------------------------------ |\n| Tree-Prompt            | [🗂️](http://csinva.io/imodelsX/treeprompt/treeprompt.html), [🔗](https://github.com/csinva/tree-prompt/tree/main), [📄](https://arxiv.org/abs/2310.14034), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/tree_prompt.ipynb),  | Explanation\u003cbr/\u003e+ Steering | Generates a tree of prompts to\u003cbr/\u003esteer an LLM (*Official*) |\n| iPrompt            | [🗂️](http://csinva.io/imodelsX/iprompt/api.html#imodelsx.iprompt.api.explain_dataset_iprompt), [🔗](https://github.com/csinva/interpretable-autoprompting), [📄](https://arxiv.org/abs/2210.01848), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/iprompt.ipynb) | Explanation\u003cbr/\u003e+ Steering | Generates a prompt that\u003cbr/\u003eexplains patterns in data (*Official*) |\n| AutoPrompt            | ㅤㅤ[🗂️](), [🔗](https://github.com/ucinlp/autoprompt), [📄](https://arxiv.org/abs/2010.15980) | Explanation\u003cbr/\u003e+ Steering | Find a natural-language prompt\u003cbr/\u003eusing input-gradients|\n| D3            | 
[🗂️](http://csinva.io/imodelsX/d3/d3.html#imodelsx.d3.d3.explain_dataset_d3), [🔗](https://github.com/ruiqi-zhong/DescribeDistributionalDifferences), [📄](https://arxiv.org/abs/2201.12323), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/d3.ipynb) | Explanation | Explain the difference between two distributions |\n| SASC            |   ㅤㅤ[🗂️](https://csinva.io/imodelsX/sasc/api.html), [🔗](https://github.com/microsoft/automated-explanations), [📄](https://arxiv.org/abs/2305.09863) | Explanation | Explain a black-box text module\u003cbr/\u003eusing an LLM (*Official*) |\n| Aug-Linear            | [🗂️](https://csinva.io/imodelsX/auglinear/auglinear.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Linear model | Fit better linear model using an LLM\u003cbr/\u003eto extract embeddings (*Official*) |\n| Aug-Tree            | [🗂️](https://csinva.io/imodelsX/augtree/augtree.html), [🔗](https://github.com/microsoft/aug-models), [📄](https://www.nature.com/articles/s41467-023-43713-1), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/aug_imodels.ipynb) | Decision tree | Fit better decision tree using an LLM\u003cbr/\u003eto expand features (*Official*) |\n| QAEmb            | [🗂️](https://csinva.io/imodelsX/qaemb/qaemb.html), [🔗](https://github.com/csinva/interpretable-embeddings), [📄](https://arxiv.org/abs/2405.16714), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/qaemb.ipynb) | Explainable\u003cbr/\u003eembedding | Generate interpretable embeddings\u003cbr/\u003eby asking LLMs questions (*Official*) |\n| KAN            | [🗂️](https://csinva.io/imodelsX/kan/kan_sklearn.html), [🔗](https://github.com/Blealtan/efficient-kan/tree/master), [📄](https://arxiv.org/abs/2404.19756), [📖](https://github.com/csinva/imodelsX/blob/master/demo_notebooks/kan.ipynb) | Small\u003cbr/\u003enetwork | 
Fit 2-layer Kolmogorov-Arnold network |\n\n\u003cp align=\"center\"\u003e\n\u003ca href=\"https://github.com/csinva/imodelsX/tree/master/demo_notebooks\"\u003e📖\u003c/a\u003e Demo notebooks \u0026emsp; \u003ca href=\"https://csinva.io/imodelsX/\"\u003e🗂️\u003c/a\u003e Docs \u0026emsp; 🔗 Reference code \u0026emsp; 📄 Research paper\n\u003cbr/\u003e\n⌛ We plan to support other interpretable algorithms like \u003ca href=\"https://arxiv.org/abs/2205.12548\"\u003eRLPrompt\u003c/a\u003e, \u003ca href=\"https://arxiv.org/abs/2007.04612\"\u003eCBMs\u003c/a\u003e, and \u003ca href=\"https://arxiv.org/abs/2004.00221\"\u003eNBDT\u003c/a\u003e. If you want to contribute an algorithm, feel free to open a PR 😄\n\u003c/p\u003e\n\n**General utilities**\n\n| Model                       | Description                                                  | \n| :-------------------------- | ------------------------------------------------------------ | \n|  [🗂️](https://csinva.io/imodelsX/llm.html)  LLM wrapper| Easily call different LLMs |\n|  [🗂️](https://csinva.io/imodelsX/data.html)  Dataset wrapper| Download minimally processed huggingface datasets |\n| [🗂️](https://csinva.io/imodelsX/linear_ngram.html) Bag of Ngrams    | Learn a linear model of ngrams |\n| [🗂️](https://csinva.io/imodelsX/linear_finetune.html) Linear Finetune  | Finetune a single linear layer on top of LLM embeddings |\n\n\n# Quickstart\n**Installation**: `pip install imodelsx` (or, for more control, clone and install from source)\n\n**Demos**: see the [demo notebooks](https://github.com/csinva/imodelsX/tree/master/demo_notebooks)\n\n\n# Natural-language explanations\n\n### Tree-prompt\n```python\nfrom imodelsx import TreePromptClassifier\nimport datasets\nimport numpy as np\nfrom sklearn.tree import plot_tree\nimport matplotlib.pyplot as plt\n\n# set up data\nrng = np.random.default_rng(seed=42)\ndset_train = datasets.load_dataset('rotten_tomatoes')['train']\ndset_train = dset_train.select(rng.choice(\n    
len(dset_train), size=100, replace=False))\ndset_val = datasets.load_dataset('rotten_tomatoes')['validation']\ndset_val = dset_val.select(rng.choice(\n    len(dset_val), size=100, replace=False))\n\n# set up arguments\nprompts = [\n    \"This movie is\",\n    \" Positive or Negative? The movie was\",\n    \" The sentiment of the movie was\",\n    \" The plot of the movie was really\",\n    \" The acting in the movie was\",\n]\nverbalizer = {0: \" Negative.\", 1: \" Positive.\"}\ncheckpoint = \"gpt2\"\n\n# fit model\nm = TreePromptClassifier(\n    checkpoint=checkpoint,\n    prompts=prompts,\n    verbalizer=verbalizer,\n    cache_prompt_features_dir=None,  # 'cache_prompt_features_dir/gpt2',\n)\nm.fit(dset_train[\"text\"], dset_train[\"label\"])\n\n\n# compute accuracy\npreds = m.predict(dset_val['text'])\nprint('\\nTree-Prompt acc (val) -\u003e',\n      np.mean(preds == dset_val['label']))  # -\u003e 0.7\n\n# compare to accuracy for individual prompts\nfor i, prompt in enumerate(prompts):\n    print(i, prompt, '-\u003e', m.prompt_accs_[i])  # -\u003e 0.65, 0.5, 0.5, 0.56, 0.51\n\n# visualize decision tree\nplot_tree(\n    m.clf_,\n    fontsize=10,\n    feature_names=m.feature_names_,\n    class_names=list(verbalizer.values()),\n    filled=True,\n)\nplt.show()\n```\n\n### iPrompt\n\n```python\nfrom imodelsx import explain_dataset_iprompt, get_add_two_numbers_dataset\n\n# get a simple dataset of adding two numbers\ninput_strings, output_strings = get_add_two_numbers_dataset(num_examples=100)\nfor i in range(5):\n    print(repr(input_strings[i]), repr(output_strings[i]))\n\n# explain the relationship between the inputs and outputs\n# with a natural-language prompt string\nprompts, metadata = explain_dataset_iprompt(\n    input_strings=input_strings,\n    output_strings=output_strings,\n    checkpoint='EleutherAI/gpt-j-6B', # which language model to use\n    num_learned_tokens=3, # how long of a prompt to learn\n    n_shots=3, # shots per example\n    n_epochs=15, # how 
many epochs to search\n    verbose=0, # how much to print\n    llm_float16=True, # whether to load the model in float16\n)\n\n# prompts is a list of found natural-language prompt strings\n```\n\n### D3 (DescribeDistributionalDifferences)\n\n```python\nfrom imodelsx import explain_dataset_d3\nhypotheses, hypothesis_scores = explain_dataset_d3(\n    pos=positive_samples, # List[str] of positive examples\n    neg=negative_samples, # another List[str]\n    num_steps=100,\n    num_folds=2,\n    batch_size=64,\n)\n```\n\n### SASC\nHere, we explain a *module* rather than a dataset.\n\n```python\nimport numpy as np\n\nfrom imodelsx import explain_module_sasc\n# a toy module that responds to the length of a string\nmod = lambda str_list: np.array([len(s) for s in str_list])\n\n# a toy dataset where the longest strings are animals\ntext_str_list = [\"red\", \"blue\", \"x\", \"1\", \"2\", \"hippopotamus\", \"elephant\", \"rhinoceros\"]\nexplanation_dict = explain_module_sasc(\n    text_str_list,\n    mod,\n    ngrams=1,\n)\n```\n\n# Aug-imodels\nUse these just like a scikit-learn model. 
During training, they fit better features via LLMs, but at test-time they are extremely fast and completely transparent.\n\n```python\nfrom imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor\nimport datasets\nimport numpy as np\n\n# set up data\ndset = datasets.load_dataset('rotten_tomatoes')['train']\ndset = dset.select(np.random.choice(len(dset), size=300, replace=False))\ndset_val = datasets.load_dataset('rotten_tomatoes')['validation']\ndset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))\n\n# fit model\nm = AugLinearClassifier(\n    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',\n    ngrams=2, # use bigrams\n)\nm.fit(dset['text'], dset['label'])\n\n# predict\npreds = m.predict(dset_val['text'])\nprint('acc_val', np.mean(preds == dset_val['label']))\n\n# interpret\nprint('Total ngram coefficients: ', len(m.coefs_dict_))\nprint('Most positive ngrams')\nfor k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:\n    print('\\t', k, round(v, 2))\nprint('Most negative ngrams')\nfor k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:\n    print('\\t', k, round(v, 2))\n```\n\n# KAN\n```python\nimport imodelsx\nfrom sklearn.datasets import make_classification, make_regression\nfrom sklearn.metrics import accuracy_score\nimport numpy as np\n\nX, y = make_classification(n_samples=5000, n_features=5, n_informative=3)\nmodel = imodelsx.KANClassifier(hidden_layer_size=64, device='cpu',\n                               regularize_activation=1.0, regularize_entropy=1.0)\nmodel.fit(X, y)\ny_pred = model.predict(X)\nprint('Train acc', accuracy_score(y, y_pred))\n\n# now try regression\nX, y = make_regression(n_samples=5000, n_features=5, n_informative=3)\nmodel = imodelsx.kan.KANRegressor(hidden_layer_size=64, device='cpu',\n                                  regularize_activation=1.0, regularize_entropy=1.0)\nmodel.fit(X, y)\ny_pred = 
model.predict(X)\nprint('Train correlation', np.corrcoef(y, y_pred.flatten())[0, 1])\n```\n\n\n# General utilities\n\n### Easy baselines\nEasy-to-fit baselines that follow the sklearn API.\n\n```python\nfrom imodelsx import LinearFinetuneClassifier, LinearNgramClassifier\n# fit a simple one-layer finetune on top of LLM embeddings\n# (dset and dset_val are the rotten_tomatoes splits loaded in the Aug-imodels example above)\nm = LinearFinetuneClassifier(\n    checkpoint='distilbert-base-uncased',\n)\nm.fit(dset['text'], dset['label'])\npreds = m.predict(dset_val['text'])\nacc = (preds == dset_val['label']).mean()\nprint('validation acc', acc)\n```\n\n### LLM wrapper\nEasy API for calling different language models with caching (much more lightweight than [langchain](https://github.com/langchain-ai/langchain)).\n\n```python\nimport imodelsx.llm\n# supports any huggingface checkpoint or openai checkpoint (including chat models)\nllm = imodelsx.llm.get_llm(\n    checkpoint=\"gpt2-xl\",  # text-davinci-003, gpt-3.5-turbo, ...\n    CACHE_DIR=\".cache\",\n)\nout = llm(\"May the Force be\")\nllm(\"May the Force be\") # when computing the same string again, uses the cache\n```\n\n### Data wrapper\nAPI for loading huggingface datasets with basic preprocessing.\n```python\nimport imodelsx.data\ndset, dataset_key_text = imodelsx.data.load_huggingface_dataset('ag_news')\n# Ensures that dset has splits named 'train' and 'validation',\n# and that each split stores its input text in the column given by dataset_key_text\n```\n\n# Related work\n- imodels package (JOSS 2021 [github](https://github.com/csinva/imodels)) - interpretable ML package for concise, transparent, and accurate predictive modeling (sklearn-compatible).\n- Rethinking Interpretability in the Era of Large Language Models (arXiv 2024 [pdf](https://arxiv.org/abs/2402.01761)) - overview of using LLMs to interpret datasets and yield natural-language explanations\n- Experiments in clinical rule development: [github](https://github.com/csinva/clinical-rule-development)\n- Experiments in 
automatically generating brain explanations: [github](https://github.com/microsoft/automated-brain-explanations)\n- Interpretation regularization (ICML 2020 [pdf](https://arxiv.org/abs/1909.13584), [github](https://github.com/laura-rieger/deep-explanation-penalization)) - penalizes CD / ACD scores during training to make models generalize better\n","funding_links":[],"categories":["Tools","Python"],"sub_categories":["Interpretable Models"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fcsinva%2FimodelsX","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fcsinva%2FimodelsX","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fcsinva%2FimodelsX/lists"}