https://github.com/joshd898/caretmultimodal
Multimodal model training in R
- Host: GitHub
- URL: https://github.com/joshd898/caretmultimodal
- Owner: JoshD898
- License: other
- Created: 2025-01-11T22:27:31.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2025-05-14T18:15:24.000Z (10 months ago)
- Last Synced: 2025-06-01T06:53:02.971Z (9 months ago)
- Topics: caret, ensemble-learning, multimodal-learning, r, testthat
- Language: R
- Homepage:
- Size: 2.28 MB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE.md

# caretMultimodal
caretMultimodal is a wrapper around the
[caret](https://github.com/topepo/caret) package that allows for
simplified multi-dataset training and ensembling. It is heavily inspired
by Zach Mayer's
[caretEnsemble](https://github.com/zachmayer/caretEnsemble) package.
## Example Usage
For the following examples, we will use [these publicly available data
sets](https://amritsingh.shinyapps.io/omicsBioAnalytics/) on heart
failure. The data sets are described
[here](https://pubmed.ncbi.nlm.nih.gov/30935638/).
Let's train one model on rows 10-20 of each of the **cells, holter, and
proteins** data sets to predict patient **hospitalization** using the
**random forest (RF)** method.
### Creating a `caret_list` object
``` r
# Load the heart failure data
load(system.file("sample_data", "HeartFailure.RData", package = "caretMultimodal"))

# Train one random forest per data set on rows 10-20
models <- caretMultimodal::caret_list(
  target = demo$hospitalizations[10:20],
  data_list = list(
    cells = cells[10:20, ],
    holter = holter[10:20, ],
    proteins = proteins[10:20, ]
  ),
  method = "rf"
)
summary(models)
#> The following models were trained: cells_model, holter_model, proteins_model
#>
#> Model metrics:
#> model method metric value sd
#>
#> 1: cells_model rf ROC 0.5 0.5000000
#> 2: holter_model rf ROC 0.8 0.4472136
#> 3: proteins_model rf ROC 1.0 0.0000000
plot(models)
```

### Using `caret_stack` to stack models
The `caret_stack` function trains a new `caret::train` object on the
predictions from the models in a `caret_list`. Let's use the **GLMNET**
method to train an ensemble model on the remaining rows of the
**cells, holter, and proteins** data sets.
``` r
# Train the ensemble on the rows that were not used for the base models
stack <- caretMultimodal::caret_stack(
  caret_list = models,
  data_list = list(
    cells = cells[-(10:20), ],
    holter = holter[-(10:20), ],
    proteins = proteins[-(10:20), ]
  ),
  target = demo$hospitalizations[-(10:20)],
  method = "glmnet"
)
summary(stack)
#> The following models were ensembled: cells_model, holter_model, proteins_model
#>
#> Relative importance:
#> Overall
#> cells_model 36.65609
#> holter_model 15.42738
#> proteins_model 47.91653
#>
#> Model metrics (based on caret_stack training data):
#> model method metric value sd
#>
#> 1: ensemble glmnet ROC 0.7392857 0.16540766
#> 2: cells_model rf ROC 0.5977564 0.11202056
#> 3: holter_model rf ROC 0.6666667 0.08445071
#> 4: proteins_model rf ROC 0.6602564 0.09410487
predict(
  stack,
  new_data_list = list(cells = cells, holter = holter, proteins = proteins)
)
#> Yes
#>
#> 1: 0.25371960
#> 2: 0.07682839
#> 3: 0.30947239
#> 4: 0.26927079
#> ...
plot(stack)
```
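
To make the idea behind stacking concrete, here is a minimal sketch that uses only base R. It is not how `caret_stack` is implemented: the data is simulated, `glm` stands in for both the base learners and the meta-learner, and every name in it (`modality_a`, `base_models`, `meta_model`, and so on) is made up for illustration.

``` r
# Conceptual sketch of stacking across data sets (base R only).
# Everything here is simulated; no names come from caretMultimodal.
set.seed(42)
n <- 200
y <- rbinom(n, 1, 0.5)

# Two hypothetical "modalities" that each carry some signal about y
modality_a <- data.frame(a1 = y + rnorm(n), a2 = rnorm(n))
modality_b <- data.frame(b1 = y + rnorm(n, sd = 2), b2 = rnorm(n))
data_list <- list(a = modality_a, b = modality_b)

base_rows <- 1:100    # rows used to fit the base models
stack_rows <- 101:200 # held-out rows used to fit the meta-learner

# Step 1: fit one base model per data set
base_models <- lapply(data_list, function(d) {
  train_df <- cbind(d[base_rows, , drop = FALSE], y = y[base_rows])
  glm(y ~ ., data = train_df, family = binomial())
})

# Step 2: collect each base model's predicted probabilities on held-out rows
base_preds <- as.data.frame(Map(
  function(m, d) predict(m, newdata = d[stack_rows, , drop = FALSE], type = "response"),
  base_models,
  data_list
))

# Step 3: fit the meta-learner on those predictions
meta_model <- glm(
  y ~ .,
  data = cbind(base_preds, y = y[stack_rows]),
  family = binomial()
)
ensemble_prob <- predict(meta_model, newdata = base_preds, type = "response")
head(ensemble_prob)
```

Fitting the meta-learner on rows the base models never saw mirrors the split used in the example above, and keeps the ensemble from simply rewarding whichever base model overfit its training rows the most.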

## Installation
The package can be installed from GitHub using devtools:
``` r
devtools::install_github("JoshD898/caretMultimodal")
```
## Project Structure
This project generally follows the [Tidyverse style
guide](https://style.tidyverse.org/).
### Naming Conventions
- Internal (non-exported) functions are prefixed with `.` to mark them
as internal and keep them out of default `ls()` listings.
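
As a rough illustration of this convention, the sketch below defines a dot-prefixed helper and a function that calls it. Both function names are hypothetical and do not come from the package. Note that in R the leading dot only hides a name from default `ls()` listings; whether a function is exported is decided by the package's `NAMESPACE` file.

``` r
# Hypothetical internal helper: the leading dot marks it as internal.
.check_named_list <- function(x) {
  is.list(x) && !is.null(names(x)) && all(nzchar(names(x)))
}

# Hypothetical user-facing function that relies on the helper.
count_data_sets <- function(data_list) {
  if (!.check_named_list(data_list)) {
    stop("`data_list` must be a fully named list")
  }
  length(data_list)
}

n_sets <- count_data_sets(list(cells = 1:3, holter = 4:6))

# Dot-prefixed names are omitted by ls() unless all.names = TRUE
demo_env <- new.env()
assign(".check_named_list", .check_named_list, envir = demo_env)
assign("count_data_sets", count_data_sets, envir = demo_env)
visible <- ls(demo_env)
everything <- ls(demo_env, all.names = TRUE)
```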
### File Organization
Each object’s definition and its associated methods are contained within
a single file.
- **`caret_list.R`** – Defines the `caret_list` object and its
methods.
- **`caret_stack.R`** – Defines the `caret_stack` object and its
methods.
- **`helpers.R`** – Contains internal helper functions shared across
multiple objects.