mlr3 basics on “iris” - Hello World!


Basic ML operations on iris: Train, predict, score, resample and benchmark. A simple, hands-on intro to mlr3.

Bernd Bischl
03-18-2020

Goals and Prerequisites

This use case shows how to use the basic mlr3 package on the iris task, so it’s our “Hello World” example. It assumes no prior knowledge of ML or mlr3. Most of this content is also covered, in more detail, in the mlr3book. Hence we will not make many general comments, but keep it hands-on and short.

The following operations are shown: creating tasks and learners, training and predicting, scoring predictions, changing hyperparameters, resampling, and benchmarking multiple learners.

Loading basic packages

# tasks, train, predict, resample, benchmark
library("mlr3")
# about a dozen reasonable learners
library("mlr3learners")

Creating tasks and learners

Let’s work on the canonical, simple iris data set, and try out some ML algorithms. We will start by using a decision tree with default settings.

# creates mlr3 task from scratch, from a data.frame
# 'target' names the column in the dataset we want to learn to predict
task = TaskClassif$new(id = "iris", backend = iris, target = "Species")
# in this case we could also take the iris example from mlr3's dictionary of shipped example tasks
# 2 equivalent calls to create a task. The second is just sugar for the user.
task = mlr_tasks$get("iris")
task = tsk("iris")
print(task)
<TaskClassif:iris> (150 x 5)
* Target: Species
* Properties: multiclass
* Features (4):
  - dbl (4): Petal.Length, Petal.Width, Sepal.Length, Sepal.Width
# create learner from dictionary of mlr3learners
# 2 equivalent calls:
learner1 = mlr_learners$get("classif.rpart")
learner1 = lrn("classif.rpart")
print(learner1)
<LearnerClassifRpart:classif.rpart>
* Model: -
* Parameters: xval=0
* Packages: rpart
* Predict Type: response
* Feature types: logical, integer, numeric, factor, ordered
* Properties: importance, missings, multiclass, selected_features,
  twoclass, weights

Train and predict

Now the usual ML operations: Train on some observations, predict on others.

# train learner on subset of task
learner1$train(task, row_ids = 1:120)
# this is what the decision tree looks like
print(learner1$model)
n= 120 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 120 70 setosa (0.41666667 0.41666667 0.16666667)  
  2) Petal.Length< 2.45 50  0 setosa (1.00000000 0.00000000 0.00000000) *
  3) Petal.Length>=2.45 70 20 versicolor (0.00000000 0.71428571 0.28571429)  
    6) Petal.Length< 4.95 49  1 versicolor (0.00000000 0.97959184 0.02040816) *
    7) Petal.Length>=4.95 21  2 virginica (0.00000000 0.09523810 0.90476190) *
# predict using observations from task
preds = learner1$predict(task, row_ids = 121:150)
# predict using "new" observations from an external data.frame
preds = learner1$predict_newdata(newdata = iris[121:150, ])
print(preds)
<PredictionClassif> for 30 observations:
    row_id     truth   response
         1 virginica  virginica
         2 virginica versicolor
         3 virginica  virginica
---                            
        28 virginica  virginica
        29 virginica  virginica
        30 virginica  virginica

Evaluation

Let’s score our prediction object with some metrics and take a deeper look by inspecting the confusion matrix.

# list available measures from the dictionary
head(as.data.table(mlr_measures))
              key task_type     packages predict_type task_properties
1:    classif.acc   classif mlr3measures     response                
2:    classif.auc   classif mlr3measures         prob        twoclass
3:   classif.bacc   classif mlr3measures     response                
4: classif.bbrier   classif mlr3measures         prob        twoclass
5:     classif.ce   classif mlr3measures     response                
6:  classif.costs   classif                  response                
s = preds$score(msr("classif.acc"))
print(s)
classif.acc 
  0.8333333 
s = preds$score(msrs(c("classif.acc", "classif.ce")))
print(s)
classif.acc  classif.ce 
  0.8333333   0.1666667 
cm = preds$confusion
print(cm)
            truth
response     setosa versicolor virginica
  setosa          0          0         0
  versicolor      0          0         5
  virginica       0          0        25

Changing hyperpars

The learner contains information about all parameters that can be configured, including data types, constraints and defaults. We can change the hyperparameters either during construction or later through an active binding.

print(learner1$param_set)
<ParamSet>
                id    class lower upper      levels        default value
 1:       minsplit ParamInt     1   Inf                         20      
 2:      minbucket ParamInt     1   Inf             <NoDefault[3]>      
 3:             cp ParamDbl     0     1                       0.01      
 4:     maxcompete ParamInt     0   Inf                          4      
 5:   maxsurrogate ParamInt     0   Inf                          5      
 6:       maxdepth ParamInt     1    30                         30      
 7:   usesurrogate ParamInt     0     2                          2      
 8: surrogatestyle ParamInt     0     1                          0      
 9:           xval ParamInt     0   Inf                         10     0
10:     keep_model ParamLgl    NA    NA  TRUE,FALSE          FALSE      
# set hyperparameters during construction ...
learner2 = lrn("classif.rpart", predict_type = "prob", minsplit = 50)
# ... or change them later through the active binding
learner2$param_set$values$minsplit = 50
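Since learner2 was constructed with predict_type = "prob", its predictions also carry class probabilities. A minimal sketch of what that looks like, reusing the iris task and the row split from above:

```r
library(mlr3)

task = tsk("iris")
learner2 = lrn("classif.rpart", predict_type = "prob", minsplit = 50)
learner2$train(task, row_ids = 1:120)
preds = learner2$predict(task, row_ids = 121:150)
# with predict_type = "prob", each prediction row gets one probability per class
print(head(preds$prob))
```

The response column is still available; the probability matrix comes on top of it.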

Resampling

Resampling repeats the train-predict-score loop on different train-test splits and collects all results in a nice data.table.

cv10 = rsmp("cv", folds = 10)
r = resample(task, learner1, cv10)
print(r)
<ResampleResult> of 10 iterations
* Task: iris
* Learner: classif.rpart
* Warnings: 0 in 0 iterations
* Errors: 0 in 0 iterations
r$score(msrs(c("classif.acc", "classif.ce")))
                 task task_id                   learner    learner_id
 1: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 2: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 3: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 4: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 5: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 6: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 7: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 8: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
 9: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
10: <TaskClassif[45]>    iris <LearnerClassifRpart[33]> classif.rpart
            resampling resampling_id iteration              prediction
 1: <ResamplingCV[19]>            cv         1 <PredictionClassif[19]>
 2: <ResamplingCV[19]>            cv         2 <PredictionClassif[19]>
 3: <ResamplingCV[19]>            cv         3 <PredictionClassif[19]>
 4: <ResamplingCV[19]>            cv         4 <PredictionClassif[19]>
 5: <ResamplingCV[19]>            cv         5 <PredictionClassif[19]>
 6: <ResamplingCV[19]>            cv         6 <PredictionClassif[19]>
 7: <ResamplingCV[19]>            cv         7 <PredictionClassif[19]>
 8: <ResamplingCV[19]>            cv         8 <PredictionClassif[19]>
 9: <ResamplingCV[19]>            cv         9 <PredictionClassif[19]>
10: <ResamplingCV[19]>            cv        10 <PredictionClassif[19]>
    classif.acc classif.ce
 1:   0.9333333 0.06666667
 2:   1.0000000 0.00000000
 3:   0.8666667 0.13333333
 4:   1.0000000 0.00000000
 5:   0.9333333 0.06666667
 6:   0.8666667 0.13333333
 7:   1.0000000 0.00000000
 8:   1.0000000 0.00000000
 9:   0.9333333 0.06666667
10:   0.8666667 0.13333333
# peek at the internal storage object of the resample result
print(r$data)
<ResultData>
  Public:
    as_data_table: function (view = NULL, reassemble_learners = TRUE, convert_predictions = TRUE, 
    clone: function (deep = FALSE) 
    combine: function (rdata) 
    data: list
    initialize: function (data = NULL) 
    iterations: function (view = NULL) 
    learners: function (view = NULL, states = TRUE, reassemble = TRUE) 
    logs: function (view = NULL, condition) 
    prediction: function (view = NULL, predict_sets = "test") 
    predictions: function (view = NULL, predict_sets = "test") 
    resamplings: function (view = NULL) 
    sweep: function () 
    task_type: active binding
    tasks: function (view = NULL, reassemble = TRUE) 
    uhashes: function (view = NULL) 
  Private:
    deep_clone: function (name, value) 
    get_view_index: function (view) 
# get all predictions nicely concatenated in a table
preds = r$prediction()
print(preds)
<PredictionClassif> for 150 observations:
    row_id      truth   response
        22     setosa     setosa
        29     setosa     setosa
        32     setosa     setosa
---                             
        96 versicolor versicolor
       138  virginica  virginica
       145  virginica  virginica
cm = preds$confusion
print(cm)
            truth
response     setosa versicolor virginica
  setosa         50          0         0
  versicolor      0         46         5
  virginica       0          4        45
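Besides the per-fold scores from r$score(), a resample result can be condensed into a single aggregated number. A small, self-contained sketch:

```r
library(mlr3)

task = tsk("iris")
learner = lrn("classif.rpart")
cv10 = rsmp("cv", folds = 10)
r = resample(task, learner, cv10)
# aggregate() averages the measure over all 10 folds into one score
print(r$aggregate(msr("classif.ce")))
```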

Populating the learner dictionary

mlr3learners ships with about a dozen popular learners. We can list them from the dictionary. If we want more, we can load an extension package from the mlr3learners organization on GitHub. Note how the dictionary grows after the extension package is loaded.

# list the currently available learners and their backing packages
head(as.data.table(mlr_learners)[, c("key", "packages")])
                   key packages
1:   classif.cv_glmnet   glmnet
2:       classif.debug         
3: classif.featureless         
4:      classif.glmnet   glmnet
5:        classif.kknn     kknn
6:         classif.lda     MASS
# remotes::install_github("mlr3learners/mlr3learners.randomforest")
library(mlr3learners.randomforest)
print(as.data.table(mlr_learners)[, c("key", "packages")])
                     key               packages
 1:    classif.cv_glmnet                 glmnet
 2:        classif.debug                       
 3:  classif.featureless                       
 4:       classif.glmnet                 glmnet
 5:         classif.kknn                   kknn
 6:          classif.lda                   MASS
 7:      classif.log_reg                  stats
 8:     classif.multinom                   nnet
 9:  classif.naive_bayes                  e1071
10:          classif.qda                   MASS
11: classif.randomForest           randomForest
12:       classif.ranger                 ranger
13:        classif.rpart                  rpart
14:          classif.svm                  e1071
15:      classif.xgboost                xgboost
16:            dens.hist                 distr6
17:             dens.kde                 distr6
18:       regr.cv_glmnet                 glmnet
19:     regr.featureless                  stats
20:          regr.glmnet                 glmnet
21:            regr.kknn                   kknn
22:              regr.km            DiceKriging
23:              regr.lm                  stats
24:    regr.randomForest           randomForest
25:          regr.ranger                 ranger
26:           regr.rpart                  rpart
27:             regr.svm                  e1071
28:         regr.xgboost                xgboost
29:           surv.coxph survival,distr6,pracma
30:       surv.cv_glmnet                 glmnet
31:          surv.glmnet                 glmnet
32:          surv.kaplan survival,distr6,pracma
33:          surv.ranger                 ranger
34:           surv.rpart  rpart,distr6,survival
35:         surv.xgboost                xgboost
                     key               packages

Benchmarking multiple learners

The benchmark() function can conveniently compare learners on the same dataset(s).

# compare our two rpart learners from above against a random forest
learners = list(learner1, learner2, lrn("classif.randomForest"))
# every learner is resampled with the same 10-fold CV on the same task
bm_grid = benchmark_grid(task, learners, cv10)
bm = benchmark(bm_grid)
print(bm)
<BenchmarkResult> of 30 rows with 3 resampling runs
 nr task_id           learner_id resampling_id iters warnings errors
  1    iris        classif.rpart            cv    10        0      0
  2    iris        classif.rpart            cv    10        0      0
  3    iris classif.randomForest            cv    10        0      0
print(bm$aggregate(measures = msrs(c("classif.acc", "classif.ce"))))
   nr      resample_result task_id           learner_id resampling_id iters
1:  1 <ResampleResult[21]>    iris        classif.rpart            cv    10
2:  2 <ResampleResult[21]>    iris        classif.rpart            cv    10
3:  3 <ResampleResult[21]>    iris classif.randomForest            cv    10
   classif.acc classif.ce
1:   0.9266667 0.07333333
2:   0.9333333 0.06666667
3:   0.9533333 0.04666667
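As bm$aggregate() returns a regular data.table, we can, for example, rank the learners by ordering on the error column. A small sketch, using a fresh and deliberately smaller benchmark so it stands on its own:

```r
library(mlr3)
library(data.table)

set.seed(1)
task = tsk("iris")
learners = list(lrn("classif.rpart"), lrn("classif.featureless"))
bm = benchmark(benchmark_grid(task, learners, rsmp("cv", folds = 3)))
tab = bm$aggregate(msr("classif.ce"))
# order rows by classification error: best learner first
print(tab[order(classif.ce), .(learner_id, classif.ce)])
```

The featureless baseline is only there to make the ranking visible; on iris it should clearly lose to the decision tree.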

Conclusion

We left out a lot of details and other features. If you want to know more, read the mlr3book and the documentation of the mentioned packages.

Citation

For attribution, please cite this work as

Bischl (2020, March 18). mlr3gallery: mlr3 basics on "iris" - Hello World!. Retrieved from https://mlr3gallery.mlr-org.com/posts/2020-03-18-iris-mlr3-basics/

BibTeX citation

@misc{mlr3-basics-iris,
  author = {Bischl, Bernd},
  title = {mlr3gallery: mlr3 basics on "iris" - Hello World!},
  url = {https://mlr3gallery.mlr-org.com/posts/2020-03-18-iris-mlr3-basics/},
  year = {2020}
}