Skorch: A PyTorch wrapper providing an sklearn interface


Image by Kevin Ku (https://unsplash.com/photos/closeup-photo-of-eyeglasses-w7ZyuGYNpRQ)

Introduction

sklearn and PyTorch are without a doubt among the most popular machine learning and deep learning libraries for Python. For a recent project, I wanted to quickly replace a tree-based learner with a neural network while reusing as much code as possible. sklearn does provide neural networks (MLPClassifier and MLPRegressor), but they lack many architectural features that PyTorch offers (e.g., dropout layers). The biggest caveat, however, is the lack of GPU acceleration! This is where skorch comes into play: an sklearn-compatible neural network library that wraps PyTorch!

The main features advertised by skorch are:

  • A wrapper around pytorch with an sklearn interface,
  • a reduction of boilerplate code in the training loop,
  • custom callbacks,
  • no intention to reinvent the wheel!
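To make the "reduction of boilerplate" point concrete, here is a minimal sketch of the hand-written PyTorch loop that a single skorch fit() call takes care of: mini-batching, shuffling, an 80/20 train/validation split, per-epoch loss tracking, and logging. The synthetic linear data and the small network are assumptions chosen for illustration; skorch's own loop additionally handles callbacks, device placement, and more.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

# Synthetic regression data (an assumption for illustration): 20 features, 2 targets
torch.manual_seed(0)
X = torch.randn(1000, 20)
true_w = torch.randn(20, 2)
y = X @ true_w + 0.1 * torch.randn(1000, 2)

# Random 80/20 train/validation split
dataset = TensorDataset(X, y)
train_ds, valid_ds = random_split(dataset, [800, 200])
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
valid_loader = DataLoader(valid_ds, batch_size=128)

model = nn.Sequential(nn.Linear(20, 32), nn.ReLU(), nn.Linear(32, 2))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

history = []  # (train_loss, valid_loss) per epoch
for epoch in range(10):
    # Training pass over shuffled mini-batches
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * len(xb)

    # Validation pass without gradient tracking
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for xb, yb in valid_loader:
            valid_loss += criterion(model(xb), yb).item() * len(xb)

    history.append((train_loss / len(train_ds), valid_loss / len(valid_ds)))
    print(f"epoch {epoch + 1:2d}  train_loss {history[-1][0]:.4f}  "
          f"valid_loss {history[-1][1]:.4f}")
```

In skorch, all of this collapses into constructing a NeuralNetRegressor and calling fit(X, y).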

Getting started with skorch

As skorch is well documented, getting started could not be easier, as the following code demonstrates. Note that we cannot expect any significant speedup from CUDA on such a tiny fully connected network!

We begin by loading the required packages:

import numpy as np
import torch
from sklearn.datasets import make_regression
from skorch import NeuralNetRegressor
from skorch.callbacks import Checkpoint, ProgressBar
from torch import nn

# RegressorModule (shown below) is kept in a local file, torch_nn.py
from torch_nn import RegressorModule

Next, we build our RegressorModule, which defines our neural network. Any instantiation of this class expects the arguments num_inputs, hidden_units, and num_outputs, where hidden_units is a list of layer widths (e.g., [128, 64] represents two hidden layers of size 128 and 64, respectively). By default, we use rectified linear units (nonlin=nn.ReLU()) as activation functions, and no activation is applied after the output layer. reset_params() then stacks these layers into a dense, fully connected nn.Sequential network.

class RegressorModule(nn.Module):
    """Fully connected feed-forward network for (multi-output) regression."""

    def __init__(
        self,
        num_inputs,
        hidden_units,
        num_outputs,
        nonlin=nn.ReLU(),
        squeeze_output=False,
    ):
        super().__init__()
        self.input_units = num_inputs
        self.hidden_units = hidden_units
        self.output_units = num_outputs
        self.nonlin = nonlin
        self.squeeze_output = squeeze_output

        self.reset_params()

    def reset_params(self):
        """(Re)set all parameters."""
        # e.g., num_inputs=20, hidden_units=[128, 64], num_outputs=2 -> [20, 128, 64, 2]
        units = [self.input_units, *self.hidden_units, self.output_units]

        sequence = []
        for u0, u1 in zip(units, units[1:]):
            sequence.append(nn.Linear(u0, u1))
            sequence.append(self.nonlin)

        sequence = sequence[:-1]  # no activation after the output layer!

        self.sequential = nn.Sequential(*sequence)

    def forward(self, X):
        X = self.sequential(X)
        if self.squeeze_output:
            X = X.squeeze(-1)  # (N, 1) -> (N,) for single-target regression
        return X
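The heart of reset_params() is the pairing of consecutive layer widths via zip(units, units[1:]). The sketch below isolates that logic in a hypothetical helper, build_mlp (not part of the original code), to show which layers hidden_units=[128, 64] produces and that no activation follows the output layer.

```python
import torch
from torch import nn


def build_mlp(num_inputs, hidden_units, num_outputs, nonlin=nn.ReLU()):
    """Hypothetical helper mirroring RegressorModule.reset_params()."""
    units = [num_inputs, *hidden_units, num_outputs]
    sequence = []
    # Pair consecutive widths: [20, 128, 64, 2] -> (20, 128), (128, 64), (64, 2)
    for u0, u1 in zip(units, units[1:]):
        sequence.append(nn.Linear(u0, u1))
        sequence.append(nonlin)
    # Drop the activation appended after the output layer
    return nn.Sequential(*sequence[:-1])


net = build_mlp(20, [128, 64], 2)
print(net)  # Linear(20->128), ReLU, Linear(128->64), ReLU, Linear(64->2)

out = net(torch.randn(5, 20))
print(out.shape)  # torch.Size([5, 2])
```

Note that the same nn.ReLU instance is reused between all layers. This is harmless because ReLU has no parameters, but a parameterized activation such as nn.PReLU passed as nonlin would share its learnable weight across all layers.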

A simple training function might look as follows. We begin by generating dummy inputs X and targets y and specifying the architecture of our fully connected feed-forward network.

This code also demonstrates one of the advantages of using skorch: we can use predefined or custom callbacks. Here, we add three Checkpoint callbacks, each of which saves the model to the folder foo under its own fn_prefix:

  • cp_epoch (monitor=None) saves after every epoch.
  • cp_train saves after every epoch with a new best training loss.
  • cp_valid saves after every epoch with a new best validation loss.

def train():
    # Generate dummy data: 1000 samples, 20 features, 2 regression targets
    X, y = make_regression(
        1000, 20, n_informative=10, n_targets=2, random_state=0
    )
    # PyTorch expects float32; scale the targets down to a convenient range
    X, y = X.astype(np.float32), y.astype(np.float32) / 100
    architecture = [32, 16]  # two hidden layers with 32 and 16 units

    # Callbacks
    path = "foo"
    cp_epoch = Checkpoint(
        dirname=path,
        fn_prefix="epoch_",
        monitor=None,  # save after every epoch
        event_name="e",
    )
    cp_train = Checkpoint(
        dirname=path,
        fn_prefix="train_",
        monitor="train_loss_best",
        event_name="t",
    )
    cp_valid = Checkpoint(
        dirname=path,
        fn_prefix="valid_",
        monitor="valid_loss_best",
        event_name="v",
    )
    progress = ProgressBar()
    cb = [cp_epoch, cp_train, cp_valid, progress]

    model = NeuralNetRegressor(
        module=RegressorModule,
        module__num_inputs=X.shape[1],
        module__hidden_units=architecture,
        module__num_outputs=y.shape[1],
        criterion=torch.nn.MSELoss,
        # device="cuda",  # uncomment this to train with CUDA
        optimizer=torch.optim.Adam,
        max_epochs=10,
        batch_size=256,  # default: 128
        lr=1e-2,  # skorch default: 0.01 (Adam's own default is 1e-3)
        # train_split=None,  # disables the default 80/20 train/valid split
        iterator_train__shuffle=True,
        # iterator_train__num_workers=4,
        # iterator_valid__num_workers=4,
        # iterator_valid__shuffle=False,
        callbacks=cb,
    )

    # Seed before fit(): the module's weights are initialized inside fit()
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    model.fit(X, y)

Apart from Checkpoint, skorch.callbacks provides several other callbacks and allows us to define our own. Examples include EarlyStopping, which stops training when a monitored metric has not improved for a given number of epochs (its patience), and LRScheduler, which adjusts the learning rate during training. The latter wraps the schedulers in torch.optim.lr_scheduler and is therefore well suited to apply, e.g., ReduceLROnPlateau, which lowers the learning rate once the model stalls.

if __name__ == "__main__":
    train()

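Running the code yields the following output, which documents the training and validation loss along with the duration (in seconds) of each epoch. The event columns e, t, and v correspond to our three checkpoints; a + marks an epoch in which that checkpoint saved the model. Since cp_epoch has no monitor, e is set in every epoch, whereas t and v only appear when the training or validation loss reached a new best value.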

  epoch    e    t    train_loss    v    valid_loss     dur
-------  ---  ---  ------------  ---  ------------  ------
      1    +    +        3.1877    +        2.9380  0.0295
      2    +    +        2.9004    +        2.4576  0.0105
      3    +    +        2.3326    +        1.5936  0.0112
      4    +    +        1.3732    +        0.5787  0.0130
      5    +    +        0.4784    +        0.3349  0.0105
      6    +    +        0.4192             0.3570  0.0120
      7    +    +        0.2749    +        0.1664  0.0135
      8    +    +        0.1456             0.1880  0.0097
      9    +             0.1630    +        0.1354  0.0105
     10    +    +        0.0919    +        0.0892  0.0100

Loading or warm-starting a model

A Checkpoint is also well suited to loading and initializing a model. Recall that our previous models were saved to the folder foo. We can either continue training the fitted model by setting warm_start=True, or resume training from a saved Checkpoint by passing a LoadInitState callback that wraps it. Note that this …

    # Previous call at the end of train()
    # model.fit(X, y)

    # Warm start: continue training the fitted model (weights and history are kept)
    model.warm_start = True
    model.fit(X, y)

    # Load a model from a Checkpoint
    from skorch.callbacks import LoadInitState
    cb.append(LoadInitState(cp_train))

    model = NeuralNetRegressor(
        RegressorModule,
        module__num_inputs=X.shape[1],
        module__hidden_units=architecture,
        module__num_outputs=y.shape[1],
        optimizer=torch.optim.Adam,
        lr=0.1,
        callbacks=cb,
    )

    # Resume training from the state saved by cp_train, i.e., the last
    # epoch in which the training loss improved
    model.fit(X, y)

Querying model weights and biases

In some cases, it may be desirable to query the weight matrix W and bias vector b of each layer. Since skorch exposes the underlying PyTorch module as model.module_, this can easily be achieved as follows.

    # A previous call to fit() saved a model via the Checkpoint cp_train
    # model.fit(X, y)

    model = NeuralNetRegressor(
        RegressorModule,
        module__num_inputs=X.shape[1],
        module__hidden_units=architecture,
        module__num_outputs=y.shape[1],
        optimizer=torch.optim.Adam,
        lr=0.1,
        callbacks=cb,
    )

    # Initialize the net, load the checkpointed parameters, and drop
    # everything not needed for prediction (history, optimizer, ...)
    model.initialize()
    model.load_params(checkpoint=cp_train)
    model.trim_for_prediction()

    # Query weight matrices and bias vectors layer by layer
    W = []
    b = []
    for name, tensor in model.module_.named_parameters():
        if name.endswith("weight"):
            # nn.Linear stores weights as (out, in); transpose to (in, out)
            assert tensor.ndim == 2
            W.append(tensor.mT.detach().cpu().numpy())
        else:
            assert tensor.ndim == 1
            b.append(tensor.detach().cpu().numpy())

Compatibility with the sklearn API

Because skorch is compatible with sklearn, any model instantiated through skorch can be used like any other sklearn estimator, e.g., as a step in a Pipeline or inside a GridSearchCV.

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# model is a NeuralNetRegressor as defined in train()
pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', model),
])

pipe.fit(X, y)

from sklearn.model_selection import GridSearchCV

# Deactivate the skorch-internal train/valid split and verbose logging;
# also drop the Checkpoint callbacks, since cp_valid needs validation data
model.set_params(train_split=None, callbacks=None, verbose=0)
params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__hidden_units': [[64], [64, 32], [128, 64]],
}
gs = GridSearchCV(model, params, refit=False, cv=3, scoring='r2')

gs.fit(X, y)
print(gs.best_score_, gs.best_params_)

Conclusion

This concludes a first look at skorch's regressor (the classifier works analogously) and its callbacks! 😎