High-powered framework for cross-validation. Fold your data like it’s paper!
Authors: Jeremy Coyle, Nima Hejazi, Ivana Malenica, and Rachael Phillips
What's origami?
The origami R package provides a general framework for applying cross-validation schemes to particular functions. Because each fold may return an arbitrary list of results, origami accommodates a wide range of cross-validation applications.
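To make the idea concrete, here is a minimal base-R sketch of the kind of scheme origami generalizes: assign observations to V folds, apply a function to each training/validation split, have it return a named list, and combine each named element across folds. The helpers assign_folds and manual_cv are hypothetical and this is not origami's implementation; the package's own fold constructors and combining rules are richer.

```r
# Hypothetical helper: randomly assign each of n observations to one of V folds
assign_folds <- function(n, V = 10L) {
  sample(rep(seq_len(V), length.out = n))
}

# Hypothetical helper (not origami's implementation): apply `fun` to every
# training/validation split, then concatenate each named result across folds
manual_cv <- function(fun, data, V = 10L) {
  fold_id <- assign_folds(nrow(data), V)
  per_fold <- lapply(seq_len(V), function(v) {
    fun(train = data[fold_id != v, , drop = FALSE],
        valid = data[fold_id == v, , drop = FALSE])
  })
  # combine: for every result name, concatenate its values over all folds
  result_names <- names(per_fold[[1]])
  setNames(lapply(result_names, function(nm) {
    unlist(lapply(per_fold, `[[`, nm), use.names = FALSE)
  }), result_names)
}

# example: squared error of predicting held-out mpg by the training-set mean
set.seed(1)
res <- manual_cv(function(train, valid) {
  list(SE = (valid$mpg - mean(train$mpg))^2,
       n_valid = nrow(valid))
}, data = mtcars, V = 5L)
mean(res$SE)
```

Each observation lands in exactly one validation set, so res$SE holds one squared error per row of mtcars, while res$n_valid holds one count per fold.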
For standard use, we recommend installing the package from CRAN via
install.packages("origami")
You can also install a stable release of origami from GitHub via devtools with:
devtools::install_github("tlverse/origami")
For details on how best to use origami, please consult the package documentation and introductory vignette, either online or from within R.
This minimal example shows how to use origami to apply cross-validation to a simple model fit on a sample data set. In particular, we obtain a cross-validated estimate of the mean squared error of a linear regression of mpg on all other variables in mtcars:
library(stringr)
library(origami)
#> origami v1.0.5: Generalized Framework for Cross-Validation
set.seed(4795)
data(mtcars)
head(mtcars)
#>                    mpg cyl disp  hp drat    wt  qsec vs am gear carb
#> Mazda RX4         21.0   6  160 110 3.90 2.620 16.46  0  1    4    4
#> Mazda RX4 Wag     21.0   6  160 110 3.90 2.875 17.02  0  1    4    4
#> Datsun 710        22.8   4  108  93 3.85 2.320 18.61  1  1    4    1
#> Hornet 4 Drive    21.4   6  258 110 3.08 3.215 19.44  1  0    3    1
#> Hornet Sportabout 18.7   8  360 175 3.15 3.440 17.02  0  0    3    2
#> Valiant           18.1   6  225 105 2.76 3.460 20.22  1  0    3    1
# build a cv_fun that wraps around lm
cv_lm <- function(fold, data, reg_form) {
  # get name and index of outcome variable from regression formula
  out_var <- as.character(unlist(str_split(reg_form, " "))[1])
  out_var_ind <- as.numeric(which(colnames(data) == out_var))
  # split up data into training and validation sets
  train_data <- training(data)
  valid_data <- validation(data)
  # fit linear model on training set and predict on validation set
  mod <- lm(as.formula(reg_form), data = train_data)
  preds <- predict(mod, newdata = valid_data)
  # capture results to be returned as output
  out <- list(coef = data.frame(t(coef(mod))),
              SE = ((preds - valid_data[, out_var_ind])^2))
  return(out)
}
# create folds (V-fold by default) and run cv_lm on each one
folds <- make_folds(mtcars)
results <- cross_validate(cv_fun = cv_lm, folds = folds, data = mtcars,
                          reg_form = "mpg ~ .")
mean(results$SE)
#> [1] 15.18558
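The same pattern yields a cross-validated estimate of a simple mean: the "model" is just the training-set mean of mpg, scored by squared error on the held-out rows. This is a sketch under the assumption that training(), validation(), make_folds() with fold_fun = folds_vfold, and the use_future argument of cross_validate() behave as described in the package reference; cv_mean is a hypothetical name. Here the fold is passed to training() and validation() explicitly rather than being looked up from the calling frame.

```r
library(origami)

# hypothetical cv_fun: estimate the mean of mpg on the training set and
# record the squared error of that estimate for each validation row
cv_mean <- function(fold, data) {
  train_data <- training(data, fold)
  valid_data <- validation(data, fold)
  train_mean <- mean(train_data$mpg)
  list(mean = train_mean,
       SE = (valid_data$mpg - train_mean)^2)
}

set.seed(4795)
folds <- make_folds(mtcars, fold_fun = folds_vfold, V = 5)
# use_future = FALSE runs the folds sequentially with lapply
results <- cross_validate(cv_fun = cv_mean, folds = folds, data = mtcars,
                          use_future = FALSE)
results$mean      # one training-set mean per fold
mean(results$SE)  # cross-validated MSE of the mean-only predictor
```

Because both elements are numeric vectors, cross_validate() concatenates them across folds: results$mean has one entry per fold and results$SE one entry per observation.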
For details on how to write wrappers (cv_funs) for use with origami::cross_validate, please consult the documentation and vignettes that accompany the package.
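As a rough template, a cv_fun takes the fold as its first argument, receives any extra arguments forwarded through the `...` of cross_validate(), and returns a named list. The sketch below assumes the exported helpers fold_index() and validation(fold = ...) (which returns validation-set indices when no object is supplied) work as documented; cv_lm_preds is a hypothetical name. Returning a data frame per fold lets cross_validate() stack the per-fold rows into one table of out-of-fold predictions.

```r
library(origami)

# hypothetical cv_fun: `fold` comes first; `data` and `formula` are forwarded
# from cross_validate(...)
cv_lm_preds <- function(fold, data, formula) {
  train_data <- training(data, fold)
  valid_data <- validation(data, fold)
  fit <- lm(formula, data = train_data)
  # one row per validation observation, tagged with its fold and row index
  list(preds = data.frame(
    fold = fold_index(fold),
    row = validation(fold = fold),
    pred = unname(predict(fit, newdata = valid_data))
  ))
}

set.seed(4795)
folds <- make_folds(mtcars, fold_fun = folds_vfold, V = 5)
results <- cross_validate(cv_fun = cv_lm_preds, folds = folds, data = mtcars,
                          formula = hp ~ wt, use_future = FALSE)
head(results$preds)
```

Keeping the row index alongside each prediction makes it easy to re-align the out-of-fold predictions with the original data after the folds are combined.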
Contributions are very welcome. Interested contributors should consult our contribution guidelines prior to submitting a pull request.
After using the origami R package, please cite it:
@article{coyle2018origami,
  author = {Coyle, Jeremy R and Hejazi, Nima S},
  title = {origami: A Generalized Framework for Cross-Validation in R},
  journal = {The Journal of Open Source Software},
  volume = {3},
  number = {21},
  month = {January},
  year = {2018},
  publisher = {The Open Journal},
  doi = {10.21105/joss.00512},
  url = {https://doi.org/10.21105/joss.00512}
}
© 2017-2021 Jeremy R. Coyle
The contents of this repository are distributed under the GPL-3 license. See file LICENSE for details.