An increasingly thick wrapper around a data.table
containing the data for a prediction task. This contains metadata about the
particular machine learning problem, including which variables are to be
used as covariates and outcomes.
make_sl3_Task(...)
R6Class object.
Passes all arguments to the constructor. See documentation for Constructor below.
sl3_Task object
make_sl3_Task(data, covariates, outcome = NULL, outcome_type = NULL, outcome_levels = NULL,
id = NULL, weights = NULL, offset = NULL, nodes = NULL, column_names = NULL,
folds = NULL, drop_missing_outcome = FALSE, flag = TRUE)
data
A data.frame or data.table containing the analytic dataset.
covariates
A character vector of variable names that define the set of covariates.
outcome
A character vector of variable names that define the set of outcomes. Usually just one variable, although some learners support multivariate outcomes. Use sl3_list_learners("multivariate_outcome") to find such learners.
outcome_type
A Variable_type object that defines the variable type of the outcome. Alternatively, a character specifying such a type. See variable_type for details on defining variable types.
outcome_levels
A vector of levels expected for the outcome variable. If outcome_type is a character, this will be used to construct an appropriate variable_type object.
id
A character indicating which variable (if any) is to be used as an identifier for independent observations. This is necessary when the data contain clusters of dependent units (e.g., repeated measures on the same individual). The id is used to define a clustered cross-validation scheme (if folds is not already supplied to make_sl3_Task), for learners that use cross-validation as part of their fitting procedure. Use sl3_list_learners("ids") to find learners whose fitting procedures support clustered observations, and use sl3_list_learners("cv") to find learners whose fitting procedures involve cross-validation.
weights
A character indicating which variable (if any) is to be used as observation weights, for learners that support weights. Use sl3_list_learners("weights") to find such learners.
offset
A character indicating which variable (if any) is to be used as an observation offset, for learners that support offsets. Use sl3_list_learners("offset") to find such learners.
nodes
A list of character vectors as nodes. This will override the covariates, outcome, id, weights, and offset arguments if specified, serving as an alternative way to specify those arguments.
column_names
A named list of characters that maps between column names in data and how those variables are referenced in sl3_Task functions.
drop_missing_outcome
Logical indicating whether to drop observations whose outcome is missing.
flag
Logical indicating whether to notify the user when there are outcomes that are missing.
folds
An optional origami fold object, as generated by make_folds, specifying a cross-validation scheme. If NULL (default), a V-fold cross-validation scheme with V = 10 will be considered for learners that use cross-validation as part of their fitting procedure. Also, if NULL (default) and id is specified, then a clustered V-fold cross-validation procedure with 10 folds will be considered. Use sl3_list_learners("cv") to find learners whose fitting procedures involve cross-validation.
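As a minimal sketch of the constructor arguments described above, the following builds a task from a simulated data set. It assumes the sl3 package is installed; the data and the column names (subject, x1, x2, w, y) are hypothetical and exist only for illustration.

```r
library(sl3)

# Simulated data: two rows per subject, so "subject" identifies clusters.
# All column names below are invented for this example.
set.seed(1)
n <- 100
toy <- data.frame(
  subject = rep(seq_len(n / 2), each = 2),
  x1 = rnorm(n),
  x2 = rbinom(n, 1, 0.5),
  w = runif(n, 0.5, 1.5)
)
toy$y <- toy$x1 + toy$x2 + rnorm(n)

# covariates/outcome/id/weights are given as column names (characters).
# folds is left NULL, so a default cross-validation scheme is used.
task <- make_sl3_Task(
  data = toy,
  covariates = c("x1", "x2"),
  outcome = "y",
  outcome_type = "continuous",
  id = "subject",
  weights = "w"
)

task$nrow
task$nodes$covariates
```

Because id is supplied and folds is not, learners that cross-validate internally should receive the clustered scheme described under folds.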
add_interactions(interactions, warn_on_existing = TRUE)
Adds interaction terms to task, returns a task with interaction terms added to covariate list.
interactions
: A list of lists, where each sublist describes one interaction term, listing the variables that comprise it
warn_on_existing
: If TRUE, produce a warning if there is already a column with a name matching this interaction term
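A short sketch of add_interactions, assuming sl3 is installed and using hypothetical columns x1, x2, and y. The exact name given to the new interaction column is left to the package, so the example only inspects which covariate was added.

```r
library(sl3)

# Hypothetical toy data for illustration only.
set.seed(1)
toy <- data.frame(x1 = rnorm(50), x2 = rnorm(50))
toy$y <- toy$x1 * toy$x2 + rnorm(50)

task <- make_sl3_Task(toy, covariates = c("x1", "x2"), outcome = "y")

# One sublist per interaction term; here a single x1-by-x2 interaction.
task_int <- task$add_interactions(list(c("x1", "x2")))

# The returned task has the interaction appended to its covariate list;
# the original task is left unchanged.
setdiff(task_int$nodes$covariates, task$nodes$covariates)
```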
add_columns(fit_uuid, new_data, global_cols=FALSE)
Add columns to internal data, returning an updated vector of column_names
fit_uuid
: A uuid character that is used to generate unique internal column names.
This prevents two added columns with the same name overwriting each other, provided they have different fit_uuid.
new_data
: A data.table containing the columns to add
global_cols
: If TRUE, don't use the fit_uuid to make unique column names
next_in_chain(covariates=NULL, outcome=NULL, id=NULL, weights=NULL,
offset=NULL, column_names=NULL, new_nodes=NULL, ...)
Used by learner$chain methods to generate a task with the same underlying data, but redefined nodes.
Most of the parameter values are passed to the sl3_Task constructor, documented above.
covariates
: An updated covariates character vector
outcome
: An updated outcome character vector
id
: An updated id character value
weights
: An updated weights character value
offset
: An updated offset character value
column_names
: An updated column_names named list
new_nodes
: An updated list of node names
...
: Other arguments passed to the sl3_Task constructor for the new task
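To illustrate how next_in_chain redefines nodes without copying the data, here is a minimal sketch. It assumes sl3 is installed; the toy data and column names are hypothetical.

```r
library(sl3)

# Hypothetical toy data for illustration only.
set.seed(1)
toy <- data.frame(x1 = rnorm(50), x2 = rnorm(50))
toy$y <- toy$x1 + rnorm(50)

task <- make_sl3_Task(toy, covariates = c("x1", "x2"), outcome = "y")

# Same underlying data, but only x1 is treated as a covariate.
# Nodes that are not passed (here, the outcome) are carried over.
reduced <- task$next_in_chain(covariates = "x1")

reduced$nodes$covariates
reduced$nodes$outcome
reduced$nrow
```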
subset_task(row_index)
Returns a task with rows subsetted using the row_index index vector
row_index
: An index vector defining the subset
get_data(rows, columns)
Returns a data.table containing a subset of task data.
rows
: An index vector defining the rows to return
columns
: A character vector of columns to return.
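The two accessors above can be sketched as follows, assuming sl3 is installed and using hypothetical toy data. subset_task returns a new task, whereas get_data returns a plain data.table.

```r
library(sl3)

# Hypothetical toy data for illustration only.
set.seed(1)
toy <- data.frame(x1 = rnorm(50), x2 = rnorm(50))
toy$y <- toy$x1 + rnorm(50)

task <- make_sl3_Task(toy, covariates = c("x1", "x2"), outcome = "y")

# A task restricted to the first ten observations.
first_ten <- task$subset_task(1:10)
first_ten$nrow

# A data.table with five rows and the two requested columns.
chunk <- task$get_data(rows = 1:5, columns = c("x1", "y"))
dim(chunk)
```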
has_node(node_name)
Returns TRUE if the node is defined in the task
node_name
: The name of the node to look for
get_node(node_name, generator_fun=NULL)
Returns a data.table with the requested node's data
node_name
: The name of the node to look for
generator_fun
: A function(node_name, n) that can generate the node if it was not specified in the task.
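A minimal sketch of has_node and get_node with a generator_fun fallback, assuming sl3 is installed. The toy data is hypothetical, and the task is built without weights so that the generator is actually used.

```r
library(sl3)

# Hypothetical toy data for illustration only; no weights column is declared.
set.seed(1)
toy <- data.frame(x1 = rnorm(50), x2 = rnorm(50))
toy$y <- toy$x1 + rnorm(50)

task <- make_sl3_Task(toy, covariates = c("x1", "x2"), outcome = "y")

has_outcome <- task$has_node("outcome")  # defined at construction
has_weights <- task$has_node("weights")  # not defined for this task

# Since the weights node is undefined, generator_fun is called with the
# node name and the number of observations; here it supplies unit weights.
w <- task$get_node(
  "weights",
  generator_fun = function(node_name, n) rep(1, n)
)
length(w)
```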
raw_data
Internal representation of the data
data
Formatted task data
nrow
Number of observations
nodes
A list of node variables
X
A data.table containing the covariates
X_intercept
A data.table containing the covariates and an intercept term
Y
A vector containing the outcomes
offsets
A vector containing the offset. Returns an error if the offset wasn't specified on construction
weights
A vector containing the observation weights. If weights aren't specified on construction, weights default to 1
id
A vector containing the observation units. If the ids aren't specified on construction, id returns seq_len(nrow)
folds
An origami fold object, as generated by make_folds, specifying a cross-validation scheme
uuid
A unique identifier of this task
column_names
The named list mapping variable names to internal column names
outcome_type
A variable_type object specifying the type of the outcome
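The fields above can be read directly from a task object. The sketch below assumes sl3 is installed and uses hypothetical toy data; the task is created without id or weights to show the documented defaults. X_intercept is used here as the accessor that returns the covariates plus an intercept column.

```r
library(sl3)

# Hypothetical toy data for illustration only.
set.seed(1)
toy <- data.frame(x1 = rnorm(50), x2 = rnorm(50))
toy$y <- toy$x1 + rnorm(50)

task <- make_sl3_Task(toy, covariates = c("x1", "x2"), outcome = "y")

task$nrow                 # number of observations
dim(task$X)               # covariates only
dim(task$X_intercept)     # covariates plus an intercept column
head(task$Y)              # outcome vector

# Defaults when weights and id are not given at construction.
unique(task$weights)
head(task$id)
```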