mlr3 Basics - German Credit

visualization, classification, feature importance, german credit data set

In this use case, we teach the basics of mlr3 by training different models on the German credit dataset.

Martin Binder, Florian Pfisterer, Michel Lang
03-11-2020

Intro

This is the first part in a series of tutorials. The other parts of this series can be found here:

We will walk through this tutorial interactively. The text is kept short to be followed in real time.

Prerequisites

Ensure all packages used in this tutorial are installed. This includes the mlr3verse package, as well as other packages for data handling, cleaning and visualization which we are going to use (data.table, ggplot2, rchallenge, and skimr).
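If anything is missing, everything can be installed in one go (a sketch; adjust to the packages you already have):

install.packages(c("mlr3verse", "data.table", "ggplot2", "rchallenge", "skimr"))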

Then, load the main packages we are going to use:
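A minimal setup sketch; mlr3verse attaches mlr3 together with several extension packages:

library("data.table")
library("ggplot2")
library("mlr3verse")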

Machine Learning Use Case: German Credit Data

The German credit data was originally donated in 1994 by Prof. Dr. Hans Hofmann of the University of Hamburg. A description can be found at the UCI repository. The goal is to classify people by their credit risk (good or bad) using 20 personal, demographic and financial features:

Feature Name Description
age age in years
amount amount asked by applicant
credit_history past credit history of applicant at this bank
duration duration of the credit in months
employment_duration present employment since
foreign_worker is applicant foreign worker?
housing type of apartment rented, owned, for free / no payment
installment_rate installment rate in percentage of disposable income
job current job information
number_credits number of existing credits at this bank
other_debtors other debtors/guarantors present?
other_installment_plans other installment plans the applicant is paying
people_liable number of people being liable to provide maintenance
personal_status_sex combination of sex and personal status of applicant
present_residence present residence since
property properties that applicant has
purpose reason customer is applying for a loan
savings savings accounts/bonds at this bank
status status/balance of checking account at this bank
telephone is there any telephone registered for this customer?

Importing the Data

The dataset we are going to use is a transformed version of this German credit dataset, as provided by the rchallenge package (this transformed dataset was proposed by Ulrike Grömping, with factors instead of dummy variables and corrected features):

data("german", package = "rchallenge")

First, we will take a closer look at the dataset.

Exploring the Data

We can get a quick overview of our dataset by checking its dimensions with dim() and its structure with str():

dim(german)
[1] 1000   21
str(german)
'data.frame':   1000 obs. of  21 variables:
 $ status                 : Factor w/ 4 levels "no checking account",..: 1 1 2 1 1 1 1 1 4 2 ...
 $ duration               : int  18 9 12 12 12 10 8 6 18 24 ...
 $ credit_history         : Factor w/ 5 levels "delay in paying off in the past",..: 5 5 3 5 5 5 5 5 5 3 ...
 $ purpose                : Factor w/ 11 levels "others","car (new)",..: 3 1 10 1 1 1 1 1 4 4 ...
 $ amount                 : int  1049 2799 841 2122 2171 2241 3398 1361 1098 3758 ...
 $ savings                : Factor w/ 5 levels "unknown/no savings account",..: 1 1 2 1 1 1 1 1 1 3 ...
 $ employment_duration    : Factor w/ 5 levels "unemployed","< 1 yr",..: 2 3 4 3 3 2 4 2 1 1 ...
 $ installment_rate       : Ord.factor w/ 4 levels ">= 35"<"25 <= ... < 35"<..: 4 2 2 3 4 1 1 2 4 1 ...
 $ personal_status_sex    : Factor w/ 4 levels "male : divorced/separated",..: 2 3 2 3 3 3 3 3 2 2 ...
 $ other_debtors          : Factor w/ 3 levels "none","co-applicant",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ present_residence      : Ord.factor w/ 4 levels "< 1 yr"<"1 <= ... < 4 yrs"<..: 4 2 4 2 4 3 4 4 4 4 ...
 $ property               : Factor w/ 4 levels "unknown / no property",..: 2 1 1 1 2 1 1 1 3 4 ...
 $ age                    : int  21 36 23 39 38 48 39 40 65 23 ...
 $ other_installment_plans: Factor w/ 3 levels "bank","stores",..: 3 3 3 3 1 3 3 3 3 3 ...
 $ housing                : Factor w/ 3 levels "for free","rent",..: 1 1 1 1 2 1 2 2 2 1 ...
 $ number_credits         : Ord.factor w/ 4 levels "1"<"2-3"<"4-5"<..: 1 2 1 2 2 2 2 1 2 1 ...
 $ job                    : Factor w/ 4 levels "unemployed/unskilled - non-resident",..: 3 3 2 2 2 2 2 2 1 1 ...
 $ people_liable          : Factor w/ 2 levels "3 or more","0 to 2": 2 1 2 1 2 1 2 1 2 2 ...
 $ telephone              : Factor w/ 2 levels "no","yes (under customer name)": 1 1 1 1 1 1 1 1 1 1 ...
 $ foreign_worker         : Factor w/ 2 levels "yes","no": 2 2 2 1 1 1 1 1 2 2 ...
 $ credit_risk            : Factor w/ 2 levels "bad","good": 2 2 2 2 2 2 2 2 2 2 ...

Our dataset has 1000 observations and 21 columns. The variable we want to predict is credit_risk (either good or bad), i.e., we aim to classify people by their credit risk.

We also recommend the skimr package, as it creates well-readable and informative overviews:

skimr::skim(german)
Table 1: Data summary
Name german
Number of rows 1000
Number of columns 21
_______________________
Column type frequency:
factor 18
numeric 3
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
status 0 1 FALSE 4 …: 394, no : 274, …: 269, 0<=: 63
credit_history 0 1 FALSE 5 no : 530, all: 293, exi: 88, cri: 49
purpose 0 1 FALSE 10 fur: 280, oth: 234, car: 181, car: 103
savings 0 1 FALSE 5 unk: 603, …: 183, …: 103, 100: 63
employment_duration 0 1 FALSE 5 1 <: 339, >= : 253, 4 <: 174, < 1: 172
installment_rate 0 1 TRUE 4 < 2: 476, 25 : 231, 20 : 157, >= : 136
personal_status_sex 0 1 FALSE 4 mal: 548, fem: 310, fem: 92, mal: 50
other_debtors 0 1 FALSE 3 non: 907, gua: 52, co-: 41
present_residence 0 1 TRUE 4 >= : 413, 1 <: 308, 4 <: 149, < 1: 130
property 0 1 FALSE 4 bui: 332, unk: 282, car: 232, rea: 154
other_installment_plans 0 1 FALSE 3 non: 814, ban: 139, sto: 47
housing 0 1 FALSE 3 ren: 714, for: 179, own: 107
number_credits 0 1 TRUE 4 1: 633, 2-3: 333, 4-5: 28, >= : 6
job 0 1 FALSE 4 ski: 630, uns: 200, man: 148, une: 22
people_liable 0 1 FALSE 2 0 t: 845, 3 o: 155
telephone 0 1 FALSE 2 no: 596, yes: 404
foreign_worker 0 1 FALSE 2 no: 963, yes: 37
credit_risk 0 1 FALSE 2 goo: 700, bad: 300

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
duration 0 1 20.90 12.06 4 12.0 18.0 24.00 72 ▇▇▂▁▁
amount 0 1 3271.25 2822.75 250 1365.5 2319.5 3972.25 18424 ▇▂▁▁▁
age 0 1 35.54 11.35 19 27.0 33.0 42.00 75 ▇▆▃▁▁

During an exploratory analysis, meaningful discoveries could include skewed feature distributions (such as that of amount above) or heavily imbalanced factor levels (such as foreign_worker). An exploratory analysis is crucial to get a feeling for your data, and it also serves to validate the data: implausible values can be investigated and outliers can be removed.
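For instance, a quick boxplot (one possible sketch using ggplot2) compares the requested credit amount across the two risk classes:

ggplot(german, aes(x = credit_risk, y = amount)) +
  geom_boxplot()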

Once we feel confident about the data, we can start modeling.

Modeling

How we tackle the problem of classifying credit risk relates closely to which mlr3 entities we will use.

The typical questions that arise when building a machine learning workflow are: What do we want to predict? Which learning algorithm should we use? How do we evaluate “good” performance?

More systematically, in mlr3 they can be expressed via five components:

  1. The Task definition.
  2. The Learner definition.
  3. The training.
  4. The prediction.
  5. The evaluation via one or multiple Measures.

Task Definition

First, we are interested in the target which we want to model. Most supervised machine learning problems are regression or classification problems. However, note that other problems include unsupervised learning or time-to-event data (covered in mlr3proba).

Within mlr3, to distinguish between these problems, we define Tasks. If we want to solve a classification problem, we define a classification task – TaskClassif. For a regression problem, we define a regression task – TaskRegr.

In our case, the objective is clearly to model or predict the binary factor variable credit_risk. We therefore define a TaskClassif:

task = as_task_classif(german, id = "GermanCredit", target = "credit_risk")

Note that the German credit data is also given as an example task which ships with the mlr3 package. Thus, you actually don’t need to construct it yourself, just call tsk("german_credit") to retrieve the object from the dictionary mlr_tasks.
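For example, the following retrieves an essentially equivalent task from the dictionary and prints a short summary of its target and features:

task = tsk("german_credit")
print(task)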

Learner Definition

After having decided what should be modeled, we need to decide on how. This means we need to decide which learning algorithms, or Learners are appropriate. Using prior knowledge (e.g. knowing that it is a classification task or assuming that the classes are linearly separable) one ends up with one or more suitable Learners.

Many learners can be obtained via the mlr3learners package. Additionally, many more learners are provided via the mlr3extralearners package, hosted on GitHub. These two resources combined account for a large fraction of standard learning algorithms. As mlr3 usually just wraps learners from other packages, it is also easy to create a formal Learner yourself. You may find the section about extending mlr3 in the mlr3book very helpful. If you happen to write your own Learner for mlr3, we would be happy if you shared it with the mlr3 community.

All available Learners (i.e. all which you have installed from mlr3, mlr3learners, mlr3extralearners, or self-written ones) are registered in the dictionary mlr_learners:

mlr_learners
<DictionaryLearner> with 136 stored values
Keys: classif.AdaBoostM1, classif.bart, classif.C50, classif.catboost, classif.cforest, classif.ctree,
  classif.cv_glmnet, classif.debug, classif.earth, classif.extratrees, classif.featureless, classif.fnn,
  classif.gam, classif.gamboost, classif.gausspr, classif.gbm, classif.glmboost, classif.glmnet,
  classif.IBk, classif.J48, classif.JRip, classif.kknn, classif.ksvm, classif.lda, classif.liblinear,
  classif.lightgbm, classif.LMT, classif.log_reg, classif.lssvm, classif.mob, classif.multinom,
  classif.naive_bayes, classif.nnet, classif.OneR, classif.PART, classif.qda, classif.randomForest,
  classif.ranger, classif.rfsrc, classif.rpart, classif.svm, classif.xgboost, clust.agnes, clust.ap,
  clust.cmeans, clust.cobweb, clust.dbscan, clust.diana, clust.em, clust.fanny, clust.featureless,
  clust.ff, clust.hclust, clust.kkmeans, clust.kmeans, clust.MBatchKMeans, clust.meanshift, clust.pam,
  clust.SimpleKMeans, clust.xmeans, dens.hist, dens.kde, dens.kde_kd, dens.kde_ks, dens.locfit,
  dens.logspline, dens.mixed, dens.nonpar, dens.pen, dens.plug, dens.spline, regr.bart, regr.catboost,
  regr.cforest, regr.ctree, regr.cubist, regr.cv_glmnet, regr.debug, regr.earth, regr.extratrees,
  regr.featureless, regr.fnn, regr.gam, regr.gamboost, regr.gausspr, regr.gbm, regr.glm, regr.glmboost,
  regr.glmnet, regr.IBk, regr.kknn, regr.km, regr.ksvm, regr.liblinear, regr.lightgbm, regr.lm,
  regr.M5Rules, regr.mars, regr.mob, regr.randomForest, regr.ranger, regr.rfsrc, regr.rpart, regr.rvm,
  regr.svm, regr.xgboost, surv.akritas, surv.blackboost, surv.cforest, surv.coxboost, surv.coxph,
  surv.coxtime, surv.ctree, surv.cv_coxboost, surv.cv_glmnet, surv.deephit, surv.deepsurv, surv.dnnsurv,
  surv.flexible, surv.gamboost, surv.gbm, surv.glmboost, surv.glmnet, surv.kaplan, surv.loghaz,
  surv.mboost, surv.nelson, surv.obliqueRSF, surv.parametric, surv.pchazard, surv.penalized, surv.ranger,
  surv.rfsrc, surv.rpart, surv.svm, surv.xgboost

For our problem, a suitable learner could be one of the following: Logistic regression, CART, random forest (or many more).

A learner can be initialized with the lrn() function and the name of the learner, e.g., lrn("classif.xxx"). Use ?mlr_learners_xxx to open the help page of a learner named xxx.

For example, a logistic regression can be initialized in the following manner (logistic regression uses R’s glm() function and is provided by the mlr3learners package):

library("mlr3learners")
learner_logreg = lrn("classif.log_reg")
print(learner_logreg)
<LearnerClassifLogReg:classif.log_reg>
* Model: -
* Parameters: list()
* Packages: mlr3, mlr3learners, stats
* Predict Type: response
* Feature types: logical, integer, numeric, character, factor, ordered
* Properties: loglik, twoclass, weights

Training

Training is the procedure where a model is fitted to the (training) data.

Logistic Regression

We start with the example of the logistic regression. However, you will immediately see that the procedure generalizes to any learner very easily.

An initialized learner can be trained on data using $train():

learner_logreg$train(task)

Typically, in machine learning, one does not use all of the available data but only a subset, the so-called training data.

To perform such a split efficiently, one could do the following:

# consider calling set.seed() beforehand to make the split reproducible
train_set = sample(task$row_ids, 0.8 * task$nrow)
test_set = setdiff(task$row_ids, train_set)

80 percent of the data is used for training. The remaining 20 percent is used for evaluation at a later point in time. train_set is an integer vector referring to the selected rows of the original dataset:

head(train_set)
[1] 201 704 132 223 305 424

In mlr3 the training with a subset of the data can be declared by the additional argument row_ids = train_set:

learner_logreg$train(task, row_ids = train_set)

The fitted model can be accessed via:

learner_logreg$model

Call:  stats::glm(formula = task$formula(), family = "binomial", data = data, 
    model = FALSE)

Coefficients:
                                              (Intercept)                                                        age  
                                                1.6153270                                                 -0.0119829  
                                                   amount     credit_historycritical account/other credits elsewhere  
                                                0.0001423                                                  0.0545112  
credit_historyno credits taken/all credits paid back duly     credit_historyexisting credits paid back duly till now  
                                               -0.7010432                                                 -1.2372112  
    credit_historyall credits at this bank paid back duly                                                   duration  
                                               -1.5176491                                                  0.0243413  
                                employment_duration< 1 yr                        employment_duration1 <= ... < 4 yrs  
                                                0.0964626                                                 -0.0735021  
                      employment_duration4 <= ... < 7 yrs                                employment_duration>= 7 yrs  
                                               -0.5708161                                                  0.2523050  
                                         foreign_workerno                                                housingrent  
                                                1.5302123                                                 -0.6606734  
                                               housingown                                         installment_rate.L  
                                               -1.0938588                                                  0.8780142  
                                       installment_rate.Q                                         installment_rate.C  
                                               -0.1256074                                                 -0.0181840  
                                  jobunskilled - resident                               jobskilled employee/official  
                                               -0.0950214                                                 -0.1709335  
            jobmanager/self-empl./highly qualif. employee                                           number_credits.L  
                                               -0.1830808                                                  0.0869611  
                                         number_credits.Q                                           number_credits.C  
                                                0.1165910                                                  0.2386049  
                                other_debtorsco-applicant                                     other_debtorsguarantor  
                                                0.3447404                                                 -0.8563601  
                            other_installment_plansstores                                other_installment_plansnone  
                                                0.3183363                                                 -0.4949327  
                                      people_liable0 to 2    personal_status_sexfemale : non-single or male : single  
                                               -0.4540724                                                 -0.3834644  
                personal_status_sexmale : married/widowed                         personal_status_sexfemale : single  
                                               -1.0445504                                                 -0.2953724  
                                      present_residence.L                                        present_residence.Q  
                                                0.1418704                                                 -0.3525819  
                                      present_residence.C                                       propertycar or other  
                                                0.2260436                                                  0.1971837  
        propertybuilding soc. savings agr./life insurance                                        propertyreal estate  
                                               -0.0374093                                                  0.6685406  
                                         purposecar (new)                                          purposecar (used)  
                                               -1.6972111                                                 -0.8358362  
                               purposefurniture/equipment                                    purposeradio/television  
                                               -1.0168920                                                 -0.0899452  
                               purposedomestic appliances                                             purposerepairs  
                                               -0.3296450                                                  0.0552592  
                                          purposevacation                                          purposeretraining  
                                               -2.2986114                                                 -1.1332590  
                                          purposebusiness                                       savings... <  100 DM  
                                               -1.1045794                                                 -0.3214354  
                              savings100 <= ... <  500 DM                                savings500 <= ... < 1000 DM  
                                               -0.3771564                                                 -1.2474731  
                                    savings... >= 1000 DM                                           status... < 0 DM  
                                               -1.0896252                                                 -0.2273784  
                                   status0<= ... < 200 DM           status... >= 200 DM / salary for at least 1 year  
                                               -0.9218249                                                 -1.7115087  
                       telephoneyes (under customer name)  
                                               -0.3396867  

Degrees of Freedom: 799 Total (i.e. Null);  745 Residual
Null Deviance:      989 
Residual Deviance: 720  AIC: 830

The stored object is a normal glm object and all its S3 methods work as expected:

class(learner_logreg$model)
[1] "glm" "lm" 
summary(learner_logreg$model)

Call:
stats::glm(formula = task$formula(), family = "binomial", data = data, 
    model = FALSE)

Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-2.1441  -0.6992  -0.3658   0.6900   2.9703  

Coefficients:
                                                            Estimate Std. Error z value Pr(>|z|)    
(Intercept)                                                1.615e+00  1.296e+00   1.247 0.212560    
age                                                       -1.198e-02  1.030e-02  -1.164 0.244620    
amount                                                     1.423e-04  5.055e-05   2.815 0.004882 ** 
credit_historycritical account/other credits elsewhere     5.451e-02  6.226e-01   0.088 0.930227    
credit_historyno credits taken/all credits paid back duly -7.010e-01  4.768e-01  -1.470 0.141496    
credit_historyexisting credits paid back duly till now    -1.237e+00  5.299e-01  -2.335 0.019559 *  
credit_historyall credits at this bank paid back duly     -1.518e+00  4.847e-01  -3.131 0.001740 ** 
duration                                                   2.434e-02  1.053e-02   2.312 0.020763 *  
employment_duration< 1 yr                                  9.646e-02  5.168e-01   0.187 0.851939    
employment_duration1 <= ... < 4 yrs                       -7.350e-02  5.019e-01  -0.146 0.883566    
employment_duration4 <= ... < 7 yrs                       -5.708e-01  5.404e-01  -1.056 0.290857    
employment_duration>= 7 yrs                                2.523e-01  4.983e-01   0.506 0.612645    
foreign_workerno                                           1.530e+00  6.578e-01   2.326 0.020011 *  
housingrent                                               -6.607e-01  2.607e-01  -2.534 0.011272 *  
housingown                                                -1.094e+00  5.543e-01  -1.973 0.048469 *  
installment_rate.L                                         8.780e-01  2.503e-01   3.508 0.000451 ***
installment_rate.Q                                        -1.256e-01  2.232e-01  -0.563 0.573670    
installment_rate.C                                        -1.818e-02  2.264e-01  -0.080 0.935994    
jobunskilled - resident                                   -9.502e-02  7.708e-01  -0.123 0.901889    
jobskilled employee/official                              -1.709e-01  7.452e-01  -0.229 0.818578    
jobmanager/self-empl./highly qualif. employee             -1.831e-01  7.500e-01  -0.244 0.807151    
number_credits.L                                           8.696e-02  8.255e-01   0.105 0.916099    
number_credits.Q                                           1.166e-01  6.910e-01   0.169 0.866019    
number_credits.C                                           2.386e-01  5.320e-01   0.449 0.653775    
other_debtorsco-applicant                                  3.447e-01  4.504e-01   0.765 0.444011    
other_debtorsguarantor                                    -8.564e-01  4.652e-01  -1.841 0.065629 .  
other_installment_plansstores                              3.183e-01  4.933e-01   0.645 0.518753    
other_installment_plansnone                               -4.949e-01  2.767e-01  -1.789 0.073695 .  
people_liable0 to 2                                       -4.541e-01  2.841e-01  -1.598 0.110019    
personal_status_sexfemale : non-single or male : single   -3.835e-01  4.497e-01  -0.853 0.393860    
personal_status_sexmale : married/widowed                 -1.045e+00  4.458e-01  -2.343 0.019124 *  
personal_status_sexfemale : single                        -2.954e-01  5.199e-01  -0.568 0.569923    
present_residence.L                                        1.419e-01  2.338e-01   0.607 0.544063    
present_residence.Q                                       -3.526e-01  2.225e-01  -1.585 0.113026    
present_residence.C                                        2.260e-01  2.220e-01   1.018 0.308639    
propertycar or other                                       1.972e-01  2.803e-01   0.704 0.481684    
propertybuilding soc. savings agr./life insurance         -3.741e-02  2.622e-01  -0.143 0.886564    
propertyreal estate                                        6.685e-01  5.059e-01   1.321 0.186348    
purposecar (new)                                          -1.697e+00  4.223e-01  -4.019 5.84e-05 ***
purposecar (used)                                         -8.358e-01  2.949e-01  -2.834 0.004591 ** 
purposefurniture/equipment                                -1.017e+00  2.773e-01  -3.667 0.000245 ***
purposeradio/television                                   -8.995e-02  1.010e+00  -0.089 0.929042    
purposedomestic appliances                                -3.296e-01  5.756e-01  -0.573 0.566875    
purposerepairs                                             5.526e-02  4.272e-01   0.129 0.897069    
purposevacation                                           -2.299e+00  1.219e+00  -1.886 0.059362 .  
purposeretraining                                         -1.133e+00  3.913e-01  -2.896 0.003776 ** 
purposebusiness                                           -1.105e+00  8.587e-01  -1.286 0.198305    
savings... <  100 DM                                      -3.214e-01  3.287e-01  -0.978 0.328160    
savings100 <= ... <  500 DM                               -3.772e-01  4.745e-01  -0.795 0.426671    
savings500 <= ... < 1000 DM                               -1.247e+00  5.460e-01  -2.285 0.022323 *  
savings... >= 1000 DM                                     -1.090e+00  2.964e-01  -3.676 0.000237 ***
status... < 0 DM                                          -2.274e-01  2.430e-01  -0.936 0.349419    
status0<= ... < 200 DM                                    -9.218e-01  3.926e-01  -2.348 0.018865 *  
status... >= 200 DM / salary for at least 1 year          -1.712e+00  2.685e-01  -6.374 1.85e-10 ***
telephoneyes (under customer name)                        -3.397e-01  2.279e-01  -1.491 0.136035    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 988.95  on 799  degrees of freedom
Residual deviance: 720.00  on 745  degrees of freedom
AIC: 830

Number of Fisher Scoring iterations: 5

Random Forest

Instead of the logistic regression, we could also train a random forest. We use the fast implementation from the ranger package. For this, we first define the learner and then train it.

We now additionally supply the importance argument (importance = "permutation"). Doing so, we override the default and let the learner determine feature importance based on permutation: a feature's importance is quantified by how much the model's out-of-bag performance deteriorates when the values of that feature are randomly permuted:

learner_rf = lrn("classif.ranger", importance = "permutation")
learner_rf$train(task, row_ids = train_set)

We can access the importance values using $importance():

learner_rf$importance()
                 status                duration          credit_history                  amount                 purpose 
           0.0329235598            0.0134295392            0.0104092860            0.0094535878            0.0041505318 
                savings        installment_rate                     age other_installment_plans          number_credits 
           0.0041344571            0.0039715347            0.0034731747            0.0029940401            0.0028366404 
          other_debtors                 housing                property       present_residence           people_liable 
           0.0027675592            0.0026135756            0.0025556631            0.0019050440            0.0014788754 
                    job     personal_status_sex     employment_duration               telephone          foreign_worker 
           0.0014238830            0.0011392806            0.0008483567            0.0004325367            0.0002573084 

In order to obtain a plot for the importance values, we convert the importance to a data.table and then process it with ggplot2:

importance = as.data.table(learner_rf$importance(), keep.rownames = TRUE)
colnames(importance) = c("Feature", "Importance")
ggplot(importance, aes(x = reorder(Feature, Importance), y = Importance)) +
  geom_col() + coord_flip() + xlab("")

Prediction

Let’s see what the models predict.

After training a model, the model can be used for prediction. Usually, prediction is the main purpose of machine learning models.

In our case, the model can be used to classify new credit applicants w.r.t. their associated credit risk (good vs. bad) on the basis of the features. Typically, machine learning models predict numeric values. In the regression case this is very natural. For classification, most models predict scores or probabilities. Based on these values, one can derive class predictions.

Predict Classes

First, we directly predict classes:

prediction_logreg = learner_logreg$predict(task, row_ids = test_set)
prediction_rf = learner_rf$predict(task, row_ids = test_set)
prediction_logreg
<PredictionClassif> for 200 observations:
    row_ids truth response
          2  good      bad
          6  good     good
          7  good     good
---                       
        971   bad     good
        976   bad      bad
        979   bad     good
prediction_rf
<PredictionClassif> for 200 observations:
    row_ids truth response
          2  good     good
          6  good     good
          7  good     good
---                       
        971   bad     good
        976   bad      bad
        979   bad     good

The $predict() method returns a Prediction object. It can be converted to a data.table if one wants to use it downstream.
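A minimal sketch of the conversion:

head(as.data.table(prediction_logreg))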

We can also display the prediction results aggregated in a confusion matrix:

prediction_logreg$confusion
        truth
response bad good
    bad   29   16
    good  24  131
prediction_rf$confusion
        truth
response bad good
    bad   27    5
    good  26  142

Predict Probabilities

Most learners can predict not only a class variable (“response”), but also their degree of “belief” / “uncertainty” in a given response. Typically, we achieve this by setting the $predict_type field of a Learner to "prob". Sometimes this needs to be done before the learner is trained. Alternatively, we can directly create the learner with this option: lrn("classif.log_reg", predict_type = "prob").

learner_logreg$predict_type = "prob"
learner_logreg$predict(task, row_ids = test_set)
<PredictionClassif> for 200 observations:
    row_ids truth response   prob.bad  prob.good
          2  good      bad 0.56529769 0.43470231
          6  good     good 0.12321676 0.87678324
          7  good     good 0.02640308 0.97359692
---                                             
        971   bad     good 0.27334325 0.72665675
        976   bad      bad 0.91190635 0.08809365
        979   bad     good 0.14405093 0.85594907

Note that one sometimes needs to be cautious when interpreting these values as probabilities: depending on the model, they may not be well calibrated.
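Since class predictions are derived from these probabilities by thresholding, the threshold can also be changed after prediction via $set_threshold(). A sketch, with the value 0.7 (relative to the task's positive class) chosen purely for illustration:

prediction = learner_logreg$predict(task, row_ids = test_set)
prediction$set_threshold(0.7)
prediction$confusion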

Performance Evaluation

To measure the performance of a learner on new unseen data, we usually mimic the scenario of unseen data by splitting up the data into training and test set. The training set is used for training the learner, and the test set is only used for predicting and evaluating the performance of the trained learner. Numerous resampling methods (cross-validation, bootstrap) repeat the splitting process in different ways.

Within mlr3, we need to specify the resampling strategy using the rsmp() function:

resampling = rsmp("holdout", ratio = 2/3)
print(resampling)
<ResamplingHoldout> with 1 iterations
* Instantiated: FALSE
* Parameters: ratio=0.6667

Here, we use “holdout”, a simple train-test split with a single iteration. We use the resample() function to carry out the resampling:

res = resample(task, learner = learner_logreg, resampling = resampling)
res
<ResampleResult> of 1 iterations
* Task: GermanCredit
* Learner: classif.log_reg
* Warnings: 0 in 0 iterations
* Errors: 0 in 0 iterations

The score of the default measure can be retrieved via $aggregate():

res$aggregate()
classif.ce 
 0.2732733 

The default measure in this scenario is the classification error. Lower is better.

We can easily run different resampling strategies, e.g., repeated holdout ("subsampling") or cross-validation. Most methods perform repeated train/predict cycles on different data subsets and aggregate the results (usually as the mean). Doing this manually would require us to write loops. mlr3 does the job for us:

resampling = rsmp("subsampling", repeats = 10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
classif.ce 
 0.2333333 

Instead, we could also run cross-validation:

resampling = rsmp("cv", folds = 10)
rr = resample(task, learner = learner_logreg, resampling = resampling)
rr$aggregate()
classif.ce 
     0.245 

mlr3 features many more measures. Here, we apply mlr_measures_classif.fpr for the false positive rate, and mlr_measures_classif.fnr for the false negative rate. Multiple measures can be provided as a list of measures (which can directly be constructed via msrs()):

# false positive rate
rr$aggregate(msr("classif.fpr"))
classif.fpr 
  0.1340364 
# false positive rate and false negative rate
measures = msrs(c("classif.fpr", "classif.fnr"))
rr$aggregate(measures)
classif.fpr classif.fnr 
  0.1340364   0.5051568 

There are a few more resampling methods, and quite a few more measures (implemented in mlr3measures). They are automatically registered in the respective dictionaries:

mlr_resamplings
<DictionaryResampling> with 9 stored values
Keys: bootstrap, custom, custom_cv, cv, holdout, insample, loo, repeated_cv, subsampling
mlr_measures
<DictionaryMeasure> with 87 stored values
Keys: aic, bic, classif.acc, classif.auc, classif.bacc, classif.bbrier, classif.ce, classif.costs,
  classif.dor, classif.fbeta, classif.fdr, classif.fn, classif.fnr, classif.fomr, classif.fp, classif.fpr,
  classif.logloss, classif.mbrier, classif.mcc, classif.npv, classif.ppv, classif.prauc, classif.precision,
  classif.recall, classif.sensitivity, classif.specificity, classif.tn, classif.tnr, classif.tp,
  classif.tpr, clust.ch, clust.db, clust.dunn, clust.silhouette, clust.wss, debug, dens.logloss, oob_error,
  regr.bias, regr.ktau, regr.mae, regr.mape, regr.maxae, regr.medae, regr.medse, regr.mse, regr.msle,
  regr.pbias, regr.rae, regr.rmse, regr.rmsle, regr.rrse, regr.rse, regr.rsq, regr.sae, regr.smape,
  regr.srho, regr.sse, selected_features, sim.jaccard, sim.phi, surv.brier, surv.calib_alpha,
  surv.calib_beta, surv.chambless_auc, surv.cindex, surv.dcalib, surv.graf, surv.hung_auc, surv.intlogloss,
  surv.logloss, surv.mae, surv.mse, surv.nagelk_r2, surv.oquigley_r2, surv.rmse, surv.schmid,
  surv.song_auc, surv.song_tnr, surv.song_tpr, surv.uno_auc, surv.uno_tnr, surv.uno_tpr, surv.xu_r2,
  time_both, time_predict, time_train

To get help on a resampling method, use ?mlr_resamplings_xxx, for a measure do ?mlr_measures_xxx. You can also browse the mlr3 reference online.

Note that some measures, for example AUC, require the prediction of probabilities.
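A sketch of how to score such a measure: the learner must be constructed (or modified) to predict probabilities before resampling, otherwise computing the AUC fails:

learner_prob = lrn("classif.log_reg", predict_type = "prob")
rr_prob = resample(task, learner = learner_prob, resampling = rsmp("cv", folds = 10))
rr_prob$aggregate(msr("classif.auc"))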

Performance Comparison and Benchmarks

We could compare Learners by evaluating resample() for each of them manually. However, benchmark() automatically performs resampling evaluations for multiple learners and tasks. benchmark_grid() creates fully crossed designs: Multiple Learners for multiple Tasks are compared w.r.t. multiple Resamplings.

learners = lrns(c("classif.log_reg", "classif.ranger"), predict_type = "prob")
grid = benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = rsmp("cv", folds = 10)
)
bmr = benchmark(grid)

Careful, large benchmarks may take a long time! This one should take less than a minute, however. In general, we want to use parallelization to speed things up on multi-core machines. For parallelization, mlr3 relies on the future package:

# future::plan("multisession") # uncomment for parallelization

In the benchmark we can compare different measures. Here, we look at the misclassification rate and the AUC:

measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, c("learner_id", "classif.ce", "classif.auc")]
        learner_id classif.ce classif.auc
1: classif.log_reg      0.255   0.7840177
2:  classif.ranger      0.229   0.8028581

We see that the two models perform very similarly.

Deviating from Hyperparameter Defaults

The previously shown techniques build the backbone of an mlr3-based machine learning workflow. However, in most cases one would not proceed exactly as we did: while many R packages have carefully selected default settings, these will not perform optimally in every scenario. Typically, we can select the values of such hyperparameters. The (hyper)parameters of a Learner can be accessed and set via its ParamSet $param_set:

learner_rf$param_set
<ParamSet>
                              id    class lower upper nlevels        default    parents       value
 1:                        alpha ParamDbl  -Inf   Inf     Inf            0.5                       
 2:       always.split.variables ParamUty    NA    NA     Inf <NoDefault[3]>                       
 3:                class.weights ParamUty    NA    NA     Inf                                      
 4:                      holdout ParamLgl    NA    NA       2          FALSE                       
 5:                   importance ParamFct    NA    NA       4 <NoDefault[3]>            permutation
 6:                   keep.inbag ParamLgl    NA    NA       2          FALSE                       
 7:                    max.depth ParamInt     0   Inf     Inf                                      
 8:                min.node.size ParamInt     1   Inf     Inf              1                       
 9:                     min.prop ParamDbl  -Inf   Inf     Inf            0.1                       
10:                      minprop ParamDbl  -Inf   Inf     Inf            0.1                       
11:                         mtry ParamInt     1   Inf     Inf <NoDefault[3]>                       
12:                   mtry.ratio ParamDbl     0     1     Inf <NoDefault[3]>                       
13:            num.random.splits ParamInt     1   Inf     Inf              1  splitrule            
14:                  num.threads ParamInt     1   Inf     Inf              1                      1
15:                    num.trees ParamInt     1   Inf     Inf            500                       
16:                    oob.error ParamLgl    NA    NA       2           TRUE                       
17:        regularization.factor ParamUty    NA    NA     Inf              1                       
18:      regularization.usedepth ParamLgl    NA    NA       2          FALSE                       
19:                      replace ParamLgl    NA    NA       2           TRUE                       
20:    respect.unordered.factors ParamFct    NA    NA       3         ignore                       
21:              sample.fraction ParamDbl     0     1     Inf <NoDefault[3]>                       
22:                  save.memory ParamLgl    NA    NA       2          FALSE                       
23: scale.permutation.importance ParamLgl    NA    NA       2          FALSE importance            
24:                    se.method ParamFct    NA    NA       2        infjack                       
25:                         seed ParamInt  -Inf   Inf     Inf                                      
26:         split.select.weights ParamUty    NA    NA     Inf                                      
27:                    splitrule ParamFct    NA    NA       2           gini                       
28:                      verbose ParamLgl    NA    NA       2           TRUE                       
29:                 write.forest ParamLgl    NA    NA       2           TRUE                       
                              id    class lower upper nlevels        default    parents       value
learner_rf$param_set$values$verbose = FALSE  # set a single value without discarding the others

We can choose parameters for our learners in two distinct ways. If we have prior knowledge of how the learner should be (hyper)parameterized, the way to go would be to manually set the parameters in the parameter set. In most cases, however, we would want to tune the learner so that it can search for “good” model configurations itself. For now, we only want to compare a few models.

To get an idea on which parameters can be manipulated, we can investigate the parameters of the original package version or look into the parameter set of the learner:

## ?ranger::ranger
as.data.table(learner_rf$param_set)[, .(id, class, lower, upper)]
                              id    class lower upper
 1:                        alpha ParamDbl  -Inf   Inf
 2:       always.split.variables ParamUty    NA    NA
 3:                class.weights ParamUty    NA    NA
 4:                      holdout ParamLgl    NA    NA
 5:                   importance ParamFct    NA    NA
 6:                   keep.inbag ParamLgl    NA    NA
 7:                    max.depth ParamInt     0   Inf
 8:                min.node.size ParamInt     1   Inf
 9:                     min.prop ParamDbl  -Inf   Inf
10:                      minprop ParamDbl  -Inf   Inf
11:                         mtry ParamInt     1   Inf
12:                   mtry.ratio ParamDbl     0     1
13:            num.random.splits ParamInt     1   Inf
14:                  num.threads ParamInt     1   Inf
15:                    num.trees ParamInt     1   Inf
16:                    oob.error ParamLgl    NA    NA
17:        regularization.factor ParamUty    NA    NA
18:      regularization.usedepth ParamLgl    NA    NA
19:                      replace ParamLgl    NA    NA
20:    respect.unordered.factors ParamFct    NA    NA
21:              sample.fraction ParamDbl     0     1
22:                  save.memory ParamLgl    NA    NA
23: scale.permutation.importance ParamLgl    NA    NA
24:                    se.method ParamFct    NA    NA
25:                         seed ParamInt  -Inf   Inf
26:         split.select.weights ParamUty    NA    NA
27:                    splitrule ParamFct    NA    NA
28:                      verbose ParamLgl    NA    NA
29:                 write.forest ParamLgl    NA    NA
                              id    class lower upper

For the random forest, two meaningful parameters which steer model complexity are num.trees and mtry. num.trees defaults to 500 and mtry to floor(sqrt(ncol(data) - 1)), in our case 4.
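As a quick check, we can compute this default from the task's 20 features:

floor(sqrt(length(task$feature_names)))
[1] 4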

In the following we aim to train three different learners:

  1. The default random forest.
  2. A random forest with low num.trees and low mtry.
  3. A random forest with high num.trees and high mtry.

We will benchmark their performance on the German credit dataset. For this we construct the three learners and set the parameters accordingly:

rf_med = lrn("classif.ranger", id = "med", predict_type = "prob")

rf_low = lrn("classif.ranger", id = "low", predict_type = "prob",
  num.trees = 5, mtry = 2)

rf_high = lrn("classif.ranger", id = "high", predict_type = "prob",
  num.trees = 1000, mtry = 11)

Once the learners are defined, we can benchmark them:

learners = list(rf_low, rf_med, rf_high)
grid = benchmark_grid(
  tasks = task,
  learners = learners,
  resamplings = rsmp("cv", folds = 10)
)
bmr = benchmark(grid)
print(bmr)
<BenchmarkResult> of 30 rows with 3 resampling runs
 nr      task_id learner_id resampling_id iters warnings errors
  1 GermanCredit        low            cv    10        0      0
  2 GermanCredit        med            cv    10        0      0
  3 GermanCredit       high            cv    10        0      0

We compare misclassification rate and AUC again:

measures = msrs(c("classif.ce", "classif.auc"))
performances = bmr$aggregate(measures)
performances[, .(learner_id, classif.ce, classif.auc)]
   learner_id classif.ce classif.auc
1:        low      0.246   0.7593290
2:        med      0.240   0.7985826
3:       high      0.236   0.7948181
autoplot(bmr)

The “low” setting seems to underfit a bit, while the “high” setting performs comparably to the default setting “med”.

Outlook

This tutorial was a detailed introduction to machine learning workflows within mlr3. Having followed it, you should be able to run your first models yourself. Beyond that, we took a first look at performance evaluation and benchmarking, and we showed how to customize learners.

The next parts of this series will cover additional mlr3 topics in more depth.

Appendix

Tips

The dictionaries provide a quick overview of the objects that ship with mlr3 and its extension packages:

mlr_tasks
<DictionaryTask> with 27 stored values
Keys: actg, bike_sharing, boston_housing, breast_cancer, faithful, gbcs, german_credit, grace, ilpd, iris,
  kc_housing, lung, moneyball, mtcars, optdigits, penguins, pima, precip, rats, sonar, spam, titanic,
  unemployment, usarrests, whas, wine, zoo
The mlr_learners, mlr_resamplings, and mlr_measures dictionaries were already printed in the sections above. Prediction objects, like most mlr3 objects, are R6 objects; their fields and methods can be inspected with names() and class():
names(prediction_rf)
 [1] ".__enclos_env__" "confusion"       "prob"            "response"        "missing"         "truth"          
 [7] "row_ids"         "man"             "predict_types"   "task_properties" "task_type"       "data"           
[13] "set_threshold"   "initialize"      "clone"           "filter"          "score"           "help"           
[19] "print"           "format"         
class(prediction_rf)
[1] "PredictionClassif" "Prediction"        "R6"               

Citation

For attribution, please cite this work as

Binder, et al. (2020, March 11). mlr3gallery: mlr3 Basics - German Credit. Retrieved from https://mlr3gallery.mlr-org.com/posts/2020-03-11-basics-german-credit/

BibTeX citation

@misc{binder2020mlr3,
  author = {Binder, Martin and Pfisterer, Florian and Lang, Michel},
  title = {mlr3gallery: mlr3 Basics - German Credit},
  url = {https://mlr3gallery.mlr-org.com/posts/2020-03-11-basics-german-credit/},
  year = {2020}
}