
https://github.com/torinriley/causal-embedding-correction

Mitigating treatment leakage in text embeddings using embedding decomposition and residualization for robust causal inference.

causal-inference text-embedding


# Techniques to Reduce Treatment Leakage in Text Embeddings

## Problem Overview

### **Treatment Leakage**
Treatment leakage arises when representations (e.g., text embeddings) that are used to adjust for confounding also contain information about the treatment variable. This can distort causal analysis, making it difficult to disentangle the causal effect of the treatment from the influence of confounding factors.

### **Embedding-Based Models**
Text embeddings, such as those generated by transformer models (e.g., BERT), provide high-dimensional vector representations of textual data. Because of their high capacity, however, they may also encode unintended correlations with treatment variables.

---

## Methodology

### **1. Extracting Text Embeddings**
We use the pretrained models **BERT-base-uncased**, **RoBERTa-base**, and **DistilBERT-base-uncased** to extract embeddings for textual data. The embeddings represent the semantic meaning of the input text in a high-dimensional space.

- **Step:** The `last_hidden_state` of the model is averaged across tokens to produce a single embedding vector for each text instance.
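The pooling step can be sketched in NumPy as follows. This is a minimal sketch, not the project's code: it assumes that padding tokens should be excluded from the average by using the tokenizer's `attention_mask`, and the function name `mean_pool` is hypothetical. With Hugging Face `transformers`, the inputs would be `outputs.last_hidden_state` and `inputs["attention_mask"]` converted to NumPy arrays.

```python
import numpy as np


def mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors into one embedding per text (hypothetical helper).

    last_hidden_state: (batch, seq_len, hidden) token representations.
    attention_mask: (batch, seq_len) with 1 for real tokens, 0 for padding.
    Assumption: padding positions are excluded from the average.
    """
    mask = attention_mask[..., None].astype(last_hidden_state.dtype)  # (batch, seq_len, 1)
    summed = (last_hidden_state * mask).sum(axis=1)                    # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                     # guard against empty rows
    return summed / counts


# Toy check: two texts, three token positions, hidden size 2.
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]],
                   [[2.0, 2.0], [4.0, 4.0], [6.0, 6.0]]])
mask = np.array([[1, 1, 0],
                 [1, 1, 1]])
print(mean_pool(hidden, mask))  # → [[2., 3.], [4., 4.]]
```

The padded third position of the first text is ignored, so its embedding is the mean of only the first two token vectors.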

---

### **2. High-Dimensional Treatments**
To demonstrate the methodology, a synthetic high-dimensional treatment variable was generated with 10 independent features. These features are intended to simulate a more realistic, complex treatment structure than a single binary indicator.
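A synthetic setup of this kind might look like the sketch below. The distributional choices are assumptions, not the project's: treatment features are drawn as independent standard normals, and the hypothetical `make_synthetic_data` helper builds embeddings that leak a nonlinear function of the treatment on top of independent noise, which makes leakage easy to measure.

```python
import numpy as np


def make_synthetic_data(n: int = 500, n_treat: int = 10, emb_dim: int = 32, seed: int = 0):
    """Hypothetical generator: independent treatment features plus leaky embeddings.

    Assumptions: treatment features are i.i.d. standard normal, and embeddings
    mix a nonlinear function of the treatment with treatment-unrelated noise.
    """
    rng = np.random.default_rng(seed)
    T = rng.standard_normal((n, n_treat))                  # independent treatment features
    W = rng.standard_normal((n_treat, emb_dim)) / np.sqrt(n_treat)
    leak = np.tanh(T @ W)                                  # nonlinear treatment signal
    content = rng.standard_normal((n, emb_dim))            # treatment-unrelated variation
    E = leak + content                                     # embeddings that leak treatment
    return T, E


T, E = make_synthetic_data()
print(T.shape, E.shape)  # → (500, 10) (500, 32)
```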

---

### **3. Embedding Decomposition Using Random Forest Regression**
To isolate and remove treatment-related information from embeddings, we use **Random Forest Regression**.

#### Steps:
1. **Train Regressor:**
   - A Random Forest model is trained with the high-dimensional treatment features as input and the original embeddings as the target.
   - The model can capture nonlinear and high-dimensional relationships between the treatment and the embeddings.

2. **Predict Treatment Components:**
   - The trained model predicts the treatment-related components of the embeddings.

3. **Partial Residualization:**
   - The predicted treatment components are scaled by a parameter \(\alpha\) and subtracted from the original embeddings: \(\tilde{E} = E - \alpha \hat{f}(T)\), where \(E\) is the original embedding matrix, \(T\) the treatment features, and \(\hat{f}\) the fitted Random Forest. Setting \(\alpha = 0\) leaves the embeddings unchanged, while \(\alpha = 1\) subtracts the full predicted component.
   - Intermediate values of \(\alpha\) balance treatment de-biasing against the retention of meaningful information.

4. **Propensity Scoring:**
   - To further enhance causal validity and address confounding bias, we integrate propensity scores into the methodology.

   **Steps:**

   - **Estimate Propensity Scores:** A logistic regression model estimates the propensity scores, which represent the probability of receiving treatment given observed covariates. These scores are computed for each instance in the dataset based on the high-dimensional treatment features.

   - **Evaluate Propensity Scores:**
     - The AUC (Area Under the ROC Curve) of the propensity model is calculated to validate it.
     - A high AUC means the covariates discriminate well between treated and control units; an AUC close to 1, however, also signals limited overlap between the groups, which the positivity check examines.

   - **Compute Inverse Propensity Weights:** The propensity scores are used to create inverse propensity weights. These weights adjust for covariate imbalances between treatment and control groups, mitigating confounding bias.

5. **Sensitivity Analysis:**
   - To assess the robustness of treatment effect estimates to unobserved confounders, the sensitivity analysis simulates potential unmeasured factors and evaluates their impact on the propensity scores.

6. **Positivity Check:**
   - The positivity assumption is checked by looking for extreme propensity scores (close to 0 or 1). Observations with extreme scores are flagged, since they indicate potential violations of the assumption.

7. **Balancing Diagnostics:**
   - Standardized Mean Differences (SMD) before and after weighting are visualized to assess covariate balance between treatment groups.

8. **Uncertainty Quantification:**
   - Bootstrapping is used to estimate confidence intervals for the treatment effect estimates, quantifying their sampling uncertainty.
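Steps 1 to 3 of the decomposition can be sketched with scikit-learn's `RandomForestRegressor`, which fits multi-output targets directly, so a single forest predicts every embedding dimension from the treatment features. This is an illustrative sketch under assumptions: the helper name `residualize_embeddings`, the forest settings, and the default `alpha=0.5` are hypothetical, and the treatment component is predicted out-of-fold with `cross_val_predict` rather than in-sample. That last choice is made here so that noise memorized by the forest is not subtracted from the embeddings; the methodology above does not specify which variant the project uses.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import KFold, cross_val_predict


def residualize_embeddings(T, E, alpha=0.5, n_splits=5, seed=0):
    """Partial residualization: E_adj = E - alpha * f_hat(T) (hypothetical helper).

    T: (n, k) treatment features, used as regressor inputs.
    E: (n, d) embeddings, used as a multi-output regression target.
    alpha: fraction of the predicted treatment component to subtract.
    """
    rf = RandomForestRegressor(n_estimators=100, min_samples_leaf=5, random_state=seed)
    folds = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    # Out-of-fold predictions: each row's treatment component comes from a
    # forest that never saw that row during training.
    E_treat = cross_val_predict(rf, T, E, cv=folds)
    E_adj = E - alpha * E_treat
    return E_adj, E_treat


# Example on toy data where the first four treatment features leak into E.
rng = np.random.default_rng(0)
T = rng.standard_normal((300, 10))
E = np.tanh(T[:, :4]) + 0.5 * rng.standard_normal((300, 4))
E_adj, E_treat = residualize_embeddings(T, E, alpha=0.5)
```

Larger values of `alpha` remove more treatment signal at the cost of more of the original variation.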

---

### **4. Validation and Visualizations**

#### **Correlation Analysis**
We calculate the mean absolute correlation between the embedding dimensions and the treatment dimensions to assess treatment leakage:
- **Original Embeddings:** Higher correlations with treatment.
- **Adjusted Embeddings:** Reduced correlations, confirming the removal of treatment-related signals.
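The leakage metric can be computed as the mean of |Pearson r| over all (embedding dimension, treatment dimension) pairs. A minimal NumPy sketch, assuming that is the intended definition and that no column is constant (a constant column has an undefined correlation); `mean_abs_correlation` is a hypothetical name:

```python
import numpy as np


def mean_abs_correlation(E: np.ndarray, T: np.ndarray) -> float:
    """Mean |Pearson r| over every (embedding dim, treatment dim) pair."""
    Ez = (E - E.mean(axis=0)) / E.std(axis=0)   # z-score each embedding dimension
    Tz = (T - T.mean(axis=0)) / T.std(axis=0)   # z-score each treatment dimension
    corr = Ez.T @ Tz / E.shape[0]               # (d, k) cross-correlation matrix
    return float(np.abs(corr).mean())
```

Comparing `mean_abs_correlation(E_original, T)` with `mean_abs_correlation(E_adjusted, T)` gives the before/after contrast described above.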

#### **Variance Comparison**
A comparison of total variance in the embeddings before and after partial residualization demonstrates the extent to which treatment-related variance is removed while preserving overall variability.
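Here "total variance" is taken to mean the trace of the covariance matrix, i.e. the sum of per-dimension variances; that definition and the helper names below are assumptions. A retained fraction below 1 shows how much variance the adjustment removed:

```python
import numpy as np


def total_variance(E: np.ndarray) -> float:
    """Trace of the covariance matrix = sum of per-dimension sample variances."""
    return float(E.var(axis=0, ddof=1).sum())


def variance_retained(E_orig: np.ndarray, E_adj: np.ndarray) -> float:
    """Fraction of the original total variance that remains after adjustment."""
    return total_variance(E_adj) / total_variance(E_orig)
```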

#### **Embedding Scatter Plots**
`t-SNE` scatter plots visualize the structure of embeddings:
- **Original Embeddings:** Show strong clustering by treatment, indicating leakage.
- **Adjusted Embeddings:** Show reduced clustering by treatment, indicating more treatment-agnostic embeddings.
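A 2-D t-SNE projection for these plots could be produced with scikit-learn's `TSNE`. In this sketch the helper name `tsne_2d` and the parameter choices (PCA initialization, a fixed seed) are assumptions; `perplexity` must be smaller than the number of points.

```python
import numpy as np
from sklearn.manifold import TSNE


def tsne_2d(E: np.ndarray, perplexity: float = 30.0, seed: int = 0) -> np.ndarray:
    """Project embeddings to 2-D with t-SNE (hypothetical helper).

    perplexity must be < number of samples; 30 is scikit-learn's default.
    """
    tsne = TSNE(n_components=2, perplexity=perplexity, init="pca", random_state=seed)
    return tsne.fit_transform(E)
```

Running it once on the original and once on the adjusted embeddings, then coloring each point by a treatment feature (e.g. with matplotlib's `scatter` and its `c` argument), gives the side-by-side comparison.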

#### **Distribution of Residuals**
Histograms of the residuals show them tightly centered around zero, which is consistent with the treatment-related components having been removed.

#### **Balancing Diagnostics**
Bar plots of standardized mean differences (SMD) before and after weighting illustrate the improved covariate balance achieved through propensity score weighting.
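The propensity-score pieces of the methodology (logistic-regression scores, AUC, the positivity flag, inverse propensity weights, and weighted SMDs) can be sketched together with scikit-learn and NumPy. Propensity scores require a treatment assignment, so this sketch assumes a binary indicator `t` and a covariate matrix `X`; the helper names, the 0.05 extremeness threshold, and the pooled-SD denominator for the SMD are all assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def propensity_scores(X, t):
    """P(t = 1 | X) from a logistic regression, plus the in-sample AUC."""
    model = LogisticRegression(max_iter=1000).fit(X, t)
    e = model.predict_proba(X)[:, 1]
    return e, roc_auc_score(t, e)


def ipw_weights(t, e):
    """Inverse propensity weights: 1/e for treated units, 1/(1 - e) for controls."""
    return np.where(t == 1, 1.0 / e, 1.0 / (1.0 - e))


def flag_extreme(e, eps=0.05):
    """Positivity check: flag scores within eps of 0 or 1 (eps is an assumed threshold)."""
    return (e < eps) | (e > 1 - eps)


def smd(X, t, w=None):
    """Standardized mean difference per covariate, optionally weighted.

    The denominator is the pooled unweighted standard deviation (an assumption;
    conventions vary).
    """
    if w is None:
        w = np.ones(len(t))
    treated, control = t == 1, t == 0
    m1 = np.average(X[treated], axis=0, weights=w[treated])   # (weighted) treated means
    m0 = np.average(X[control], axis=0, weights=w[control])   # (weighted) control means
    s = np.sqrt((X[treated].var(axis=0, ddof=1) + X[control].var(axis=0, ddof=1)) / 2)
    return (m1 - m0) / s
```

Plotting `np.abs(smd(X, t))` next to `np.abs(smd(X, t, ipw_weights(t, e)))` as bars gives the before/after balance comparison.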

#### **Sensitivity Analysis**
Simulations with unobserved confounders quantify the robustness of propensity score estimates to unmeasured variables.
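One way to run such a simulation is sketched below, under explicit assumptions that may differ from the project's own setup: treatment `t` is binary, the hypothetical confounder `U` is Gaussian noise shifted by treatment status with a chosen strength, it is appended to the covariates `X`, and robustness is summarized as the mean absolute change in propensity scores. The helper name `propensity_sensitivity` is hypothetical.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression


def propensity_sensitivity(X, t, strengths=(0.0, 0.5, 1.0, 2.0), seed=0):
    """Mean |change| in propensity scores when a simulated confounder U is added.

    For each strength s, U = s * (t - mean(t)) + N(0, 1) noise, so larger s
    means U is more strongly associated with treatment (an assumed design).
    """
    rng = np.random.default_rng(seed)
    base = LogisticRegression(max_iter=1000).fit(X, t).predict_proba(X)[:, 1]
    results = {}
    for s in strengths:
        U = s * (t - t.mean()) + rng.standard_normal(len(t))   # simulated unobserved factor
        XU = np.column_stack([X, U])
        e_u = LogisticRegression(max_iter=1000).fit(XU, t).predict_proba(XU)[:, 1]
        results[s] = float(np.mean(np.abs(e_u - base)))        # average score shift
    return results
```

Small shifts at moderate strengths suggest the propensity estimates are not fragile to an omitted factor of that size.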

#### **Mutual Information Analysis**
We measure the mutual information between embeddings and:
- **Treatment Features:** Quantifies treatment leakage in original embeddings.
- **Outcome:** Measures how much predictive information about the outcome the embeddings retain.
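Mutual information can be estimated with scikit-learn's `mutual_info_regression`, which uses a k-nearest-neighbour estimator, reports values in nats, and scores each embedding dimension against a single 1-D target (`mutual_info_classif` is the counterpart for a discrete target). The sketch averages over embedding dimensions and, for the treatment, over treatment features; that aggregation and the helper names are assumptions, so its numbers need not match those in the Results section.

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression


def mean_mutual_information(E, y, seed=0):
    """Average estimated MI (nats) between each embedding dimension and a 1-D target y."""
    return float(mutual_info_regression(E, y, random_state=seed).mean())


def mi_with_treatment(E, T, seed=0):
    """Average of mean_mutual_information over the columns of the treatment matrix T."""
    return float(np.mean([mean_mutual_information(E, T[:, j], seed) for j in range(T.shape[1])]))
```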

---

## Results

- **Correlation Metrics:**
  - Original Embeddings: Moderate correlation with the treatment variables.
  - Adjusted Embeddings: Substantially reduced correlations, supporting the effectiveness of partial residualization.

- **Variance Analysis:**
  - The adjusted embeddings retain a meaningful proportion of variance while the treatment-related components are removed.

- **Visualizations:**
  - Scatter plots of the original and adjusted embeddings, together with the distribution of residuals, provide qualitative and quantitative evidence of the methodology's effectiveness.

- **Positivity Check:**
  - No extreme propensity scores were detected, indicating that the positivity assumption holds.

- **Mutual Information Results:**
  - Mutual Information with Treatment: 0.449
  - Mutual Information with Outcome: 0.418

- **Estimated Causal Effect:**
  - Causal Effect Estimate: 0.0002 with 95% CI: [0.0002, 0.0003]
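The methodology does not state which effect estimator the bootstrap wraps. The sketch below assumes a binary treatment `t`, outcome `y`, and propensity scores `e`, and computes a percentile bootstrap interval around a normalized (Hajek) inverse-propensity-weighted ATE estimate. It also holds `e` fixed across resamples, a simplification: re-fitting the propensity model inside each resample would additionally capture its estimation uncertainty. Both helper names are hypothetical.

```python
import numpy as np


def ipw_ate(y, t, e):
    """Hajek (normalized) inverse-propensity-weighted average treatment effect."""
    w1 = t / e                 # weights for treated units (zero for controls)
    w0 = (1 - t) / (1 - e)     # weights for control units (zero for treated)
    return np.sum(w1 * y) / np.sum(w1) - np.sum(w0 * y) / np.sum(w0)


def bootstrap_ci(y, t, e, n_boot=1000, level=0.95, seed=0):
    """Percentile bootstrap CI, resampling units with replacement.

    Assumes every resample contains both treated and control units.
    """
    rng = np.random.default_rng(seed)
    n = len(y)
    stats = np.empty(n_boot)
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)              # resample row indices
        stats[b] = ipw_ate(y[idx], t[idx], e[idx])
    lo, hi = np.percentile(stats, [100 * (1 - level) / 2, 100 * (1 + level) / 2])
    return float(ipw_ate(y, t, e)), (float(lo), float(hi))
```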

---

# Embedding Analysis Visualizations

## Variance and Correlation Comparison


*Figures: Variance Comparison; Correlation Comparison.*

## Embedding Scatter Plots


*Figures: Original Embeddings; Adjusted Embeddings.*

---

## Conclusion

This project demonstrates a methodology for mitigating treatment leakage in text embeddings that combines:
1. **Random Forest Regression** to model and remove treatment-related components.
2. **Partial Residualization** to balance treatment de-biasing with the retention of meaningful data relationships.
3. **Propensity Score Integration** to address confounding bias and improve covariate balance.
4. **Sensitivity Analysis and Diagnostics** to validate assumptions and ensure robust causal inference.
5. **Uncertainty Quantification** using bootstrapping to provide confidence intervals for treatment effect estimates.

By making embeddings treatment-agnostic without making them treatment-blind (partial rather than full residualization), this approach enhances the reliability of causal inference models and enables more accurate estimation of causal effects.

For complete results: [Results Summary](https://github.com/torinriley/Causal-Embedding-Correction/blob/main/results/Results_Summary.md)

See the project's disclaimer: [Disclaimer](https://github.com/torinriley/Causal-Embedding-Correction/blob/main/Disclaimer.md)