https://github.com/xinxiong0238/gart
Guided Adversarial Robust Transfer Learning with Source Mixing
- Host: GitHub
- URL: https://github.com/xinxiong0238/gart
- Owner: xinxiong0238
- License: other
- Created: 2024-08-14T19:57:52.000Z
- Default Branch: master
- Last Pushed: 2024-08-16T18:25:24.000Z
- Last Synced: 2024-08-16T19:43:14.851Z
- Topics: generalization, robust-optimization, source-mixing, transfer-learning
- Language: R
- Homepage:
- Size: 15.6 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.Rmd
- License: LICENSE
README
---
output: github_document
---
```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```
# GART
Guided Adversarial Robust Transfer (GART) Learning aims to bridge transfer learning and distributionally robust prediction when only a limited amount of target data is available alongside a diverse set of source models. GART leverages the source mixing assumption to extract useful knowledge from auxiliary samples that differ from the target but may be related to it. Its goal is a faster convergence rate than a model fitted on the target data alone.
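The source mixing assumption can be illustrated with a toy linear model. The sketch below is not code from the GART package and makes several assumptions:

- Each source k has an already-fitted coefficient vector (a column of `B`).
- The target coefficients are a convex combination of these vectors.
- The mixture weights are recovered from a small target sample by least squares over the probability simplex.

The helper `fit_mixture_weights` and all simulation settings are hypothetical. The softmax reparameterization with base R's `optim()` is just one convenient way to enforce the simplex constraint.

```r
# Toy illustration of the source mixing assumption (not code from GART).
# Assumes linear models y = X b + noise, known source coefficients, and a
# target coefficient vector that is a convex combination of the sources'.
set.seed(1)
p <- 10; K <- 4; n0 <- 50

B <- matrix(rnorm(p * K), nrow = p, ncol = K)    # column k: coefficients of source k
q_true <- c(0.6, 0.3, 0.1, 0)                    # mixture weights on the simplex
b0 <- as.vector(B %*% q_true)                    # target coefficients under mixing

X0 <- matrix(rnorm(n0 * p), nrow = n0, ncol = p) # small target sample
y0 <- as.vector(X0 %*% b0 + rnorm(n0, sd = 0.3))

# Hypothetical helper: least squares over the simplex, enforced through the
# softmax reparameterization q = exp(theta) / sum(exp(theta)).
fit_mixture_weights <- function(X, y, B) {
  P <- X %*% B                                   # column k: predictions of source k
  softmax <- function(theta) {
    e <- exp(theta - max(theta))                 # shift for numerical stability
    e / sum(e)
  }
  loss <- function(theta) sum((y - P %*% softmax(theta))^2)
  opt <- optim(rep(0, ncol(B)), loss, method = "BFGS")  # start at equal weights
  softmax(opt$par)
}

q_hat <- fit_mixture_weights(X0, y0, B)
b_hat <- as.vector(B %*% q_hat)                  # mixture-based estimate of b0
round(q_hat, 3)
```

In this toy setting the target sample only has to pin down K = 4 weights instead of p = 10 coefficients. This gives some intuition for why source mixing can help when target data are scarce.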
## Installation
You can install the development version of GART from [GitHub](https://github.com/) with:
``` r
# install.packages("devtools")
devtools::install_github("xinxiong0238/GART")
```
## Example
The following is a basic example of running GART. We first generate the training data (a small target sample and four large but heterogeneous source samples) as well as a validation data set. For illustration purposes, the validation data share the same generation mechanism as the training target data. The `GART` function is then called to estimate the GART parameters and validate model performance. Five benchmarks are also included: the target-only estimator, source mixture, maximin, transLasso, and transGLM.
```{r example}
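library(GART)
# Simulate a small target sample, four source samples, and a validation set
data_sim = simu_data()
data = data_sim$GART_est_input
data_valid = data_sim$GART_eval_input
# Fit GART, run the five benchmarks, and evaluate on the validation data
fit = GART(data, is_benchmark = TRUE, is_valid = TRUE, data_valid = data_valid)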
```
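To make the validation step concrete, the following self-contained sketch mimics the described design with made-up settings. It uses a small target sample, four larger sources whose coefficients are randomly shifted away from the target's, and a validation set drawn from the target mechanism. Two simple baselines are compared by validation mean squared prediction error. This is an illustration only: it does not reproduce `simu_data()`, the package's benchmarks, or the metrics that `GART()` actually reports.

```r
# Toy sketch of the validation step (not GART's simu_data() or its metrics).
set.seed(2)
p <- 10; n0 <- 30; nk <- 500; n_valid <- 1000
b0 <- c(rep(1, 3), rep(0, p - 3))              # hypothetical target coefficients

# Draw n observations from a linear model with coefficients b
sim <- function(n, b, sd = 1) {
  X <- matrix(rnorm(n * p), nrow = n, ncol = p)
  list(X = X, y = as.vector(X %*% b + rnorm(n, sd = sd)))
}

target <- sim(n0, b0)
valid  <- sim(n_valid, b0)                     # same mechanism as the target
# Four heterogeneous sources: target coefficients plus source-specific shifts
sources <- lapply(seq_len(4), function(k) sim(nk, b0 + rnorm(p, sd = 0.5)))

# Target-only baseline: OLS on the small target sample (no intercept)
b_target <- coef(lm(target$y ~ target$X - 1))
# Naive pooled baseline: OLS on all source samples stacked together
X_pool <- do.call(rbind, lapply(sources, `[[`, "X"))
y_pool <- unlist(lapply(sources, `[[`, "y"))
b_pool <- coef(lm(y_pool ~ X_pool - 1))

valid_mse <- function(b) mean((valid$y - valid$X %*% b)^2)
c(target_only = valid_mse(b_target), pooled_sources = valid_mse(b_pool))
```

Which baseline wins depends on the sample sizes and on how far the sources drift from the target. The target-only fit is unbiased but noisy, while the pooled fit is precise but biased; a transfer method aims to improve on both.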