Batch Prediction with PyTorch

[1]:
%matplotlib inline

This example follows PyTorch’s transfer learning tutorial. We will:

  1. Finetune a pretrained convolutional neural network on a specific task (ants vs. bees).

  2. Use a Dask cluster for batch prediction with that model.

The primary focus is using a Dask cluster for batch prediction.

Note that the base environment on the examples.dask.org Binder does not include PyTorch or torchvision. To run this example, you’ll need to run

!conda install -y pytorch-cpu torchvision

which will take a bit of time to run.

Download the data

The PyTorch documentation hosts a small set of data. We’ll download and extract it locally.

[2]:
import urllib.request
import zipfile
[3]:
filename, _ = urllib.request.urlretrieve("https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip")
zipfile.ZipFile(filename).extractall()

The directory looks like

hymenoptera_data/
    train/
        ants/
            0013035.jpg
            ...
            1030023514_aad5c608f9.jpg
        bees/
            1092977343_cb42b38d62.jpg
            ...
            2486729079_62df0920be.jpg

    val/
        ants/
            0013025.jpg
            ...
            1030023514_aad5c606d9.jpg
        bees/
            1092977343_cb42b38e62.jpg
            ...
            2486729079_62df0921be.jpg
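
One quick way to confirm the extraction worked is to count the images under each split and class. The following is a minimal sketch using only the standard library. `count_images` is a hypothetical helper, not part of the tutorial, and the sketch assumes the archive was extracted into the current working directory.

```python
from pathlib import Path


def count_images(root):
    """Return {split: {class_name: number_of_jpg_files}} for a split/class/file tree."""
    counts = {}
    for split_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[split_dir.name] = {
            class_dir.name: len(list(class_dir.glob("*.jpg")))
            for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir())
        }
    return counts


# Only inspect the data if it has actually been downloaded and extracted.
if Path("hymenoptera_data").exists():
    print(count_images("hymenoptera_data"))
```

An empty class directory or a missing split shows up immediately in the printed dictionary.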

Following the tutorial, we’ll finetune the model.

[5]:
import torchvision
from tutorial_helper import (imshow, train_model, visualize_model,
                             dataloaders, class_names, finetune_model)

Finetune the model

Our base model is resnet18. It predicts 1,000 categories, while ours predicts just 2 (ants or bees). To make this model train quickly on examples.dask.org, we’ll use only two epochs.

[6]:
import dask
[7]:
%%time
model = finetune_model()
Epoch 0/1
----------
train Loss: 0.5446 Acc: 0.7213
val Loss: 0.3002 Acc: 0.8824

Epoch 1/1
----------
train Loss: 0.5793 Acc: 0.7664
val Loss: 0.1800 Acc: 0.9346

Training complete in 0m 55s
Best val Acc: 0.934641
CPU times: user 4min 8s, sys: 2min 39s, total: 6min 48s
Wall time: 55.9 s

Things seem OK on a few random images:

[8]:
visualize_model(model);
[Image: output of visualize_model(model)]

Batch Prediction with Dask

Now for the main topic: using a pretrained model for batch prediction on a Dask cluster. There are two main complications, both of which come down to minimizing the amount of data moved around:

  1. Loading the data on the workers. We’ll use dask.delayed to load the data on the workers, rather than loading it on the client and sending it to the workers.

  2. PyTorch neural networks are large. We don’t want them in Dask task graphs, and we only want to move them around once.
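
To make the first point concrete, here is a minimal sketch of how `dask.delayed` defers work. Calling a delayed function only records a task, and nothing runs until `dask.compute` is called. The functions `load_item` and `preprocess` are hypothetical stand-ins for reading and transforming one image.

```python
import dask


@dask.delayed
def load_item(i):
    # Stand-in for reading one file; the body runs only at compute time.
    return i


@dask.delayed
def preprocess(x):
    # Stand-in for a per-item transformation.
    return 2 * x


# Building the list only creates Delayed objects (a task graph).
lazy_results = [preprocess(load_item(i)) for i in range(4)]

# The tasks execute here, on whichever scheduler is active.
results = dask.compute(*lazy_results)
print(results)  # (0, 2, 4, 6)
```

With a distributed `Client` active, the same `dask.compute` call sends the tasks to the workers, so the loading happens there rather than on the client.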

[9]:
from distributed import Client

client = Client(n_workers=2, threads_per_worker=2)
client
[9]:

Client

Cluster

  • Workers: 2
  • Cores: 4
  • Memory: 31.62 GB

Loading the data on the workers

First, we’ll define a couple of helpers to load the data and preprocess it for the neural network. We’ll use dask.delayed here so that the execution is lazy and happens on the cluster. See the delayed example for more on using dask.delayed.

[10]:
import glob
import toolz
import dask
import dask.array as da
import torch
from torchvision import transforms
from PIL import Image


@dask.delayed
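def load(path, fs=__builtins__):
    # `fs` is any object with an `open(path, mode)` method, such as an s3fs
    # filesystem; the default resolves to the built-in `open` in a notebook.
    with fs.open(path, 'rb') as f:
        # Convert inside the `with` block so the pixels are read before the file closes.
        img = Image.open(f).convert("RGB")
        return img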


@dask.delayed
def transform(img):
    trn = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return trn(img)
[11]:
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]

To load the data from cloud storage, say Amazon S3, you would use

import s3fs

fs = s3fs.S3FileSystem(...)
objs = [load(x, fs=fs) for x in fs.glob(...)]

The PyTorch model expects tensors of a specific shape, so let’s transform them.

[12]:
tensors = [transform(x) for x in objs]

And the model expects batches of inputs, so let’s stack a few together.

[13]:
batches = [dask.delayed(torch.stack)(batch)
           for batch in toolz.partition_all(10, tensors)]
batches[:5]
[13]:
[Delayed('stack-0bdfad49-aef0-460f-a715-7470e74a0fdd'),
 Delayed('stack-e7685443-656f-4112-8717-eb386cd8bdf7'),
 Delayed('stack-cf6e47e0-79bc-4d60-b9ce-14da3af2177e'),
 Delayed('stack-fccfe6b7-38cf-4a05-a8c1-2da55194bf0b'),
 Delayed('stack-4aec6506-a711-496b-924f-fb5f2e6ff018')]
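
The grouping is done by `toolz.partition_all`, which splits a sequence into tuples of at most `n` items and leaves the last tuple shorter when the length is not a multiple of `n`. A small sketch with plain integers standing in for the delayed tensors:

```python
import toolz

items = list(range(23))  # stand-ins for 23 delayed tensors

# Groups of at most 10 items; the final group holds the remainder.
groups = list(toolz.partition_all(10, items))
sizes = [len(g) for g in groups]
print(sizes)  # [10, 10, 3]
```

This is why the final batch of predictions can be smaller than the others.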

Finally, we’ll write a small predict helper to predict the output class (0 or 1).

[14]:
@dask.delayed
def predict(batch, model):
    with torch.no_grad():
        out = model(batch)
        _, predicted = torch.max(out, 1)
        predicted = predicted.numpy()
    return predicted
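
The helper relies on `torch.max(out, 1)` returning both the maximum value and its index along dimension 1. The index is the predicted class. A small sketch with hand-written scores in place of real model output (the `logits` values are made up for illustration):

```python
import torch

# One row per image, one column per class (0 or 1).
logits = torch.tensor([[2.0, -1.0],
                       [0.1, 0.3],
                       [-0.5, -0.2]])

# `values` holds the largest score in each row, `predicted` its column index.
values, predicted = torch.max(logits, 1)
print(predicted.numpy())  # [0 1 1]
```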

Moving the model around

PyTorch neural networks are large, so we don’t want to repeat the model many times in our task graph (once per batch).

[15]:
import pickle

dask.utils.format_bytes(len(pickle.dumps(model)))
[15]:
'44.80 MB'
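
Most of that size is likely the weights themselves. As a rough cross-check, the bytes held by a module's parameters and buffers can be summed directly. The sketch below uses a hypothetical helper, `state_nbytes`, and a small `Linear` layer as a stand-in for the real model.

```python
import pickle

import torch


def state_nbytes(module):
    """Total bytes held by a module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)


# Stand-in layer: 512 * 2 weights plus 2 biases, stored as 4-byte floats.
layer = torch.nn.Linear(512, 2)
nbytes = state_nbytes(layer)
pickled = len(pickle.dumps(layer))
print(nbytes)  # 4104
print(pickled)  # somewhat larger, since pickling adds metadata
```

Calling `state_nbytes(model)` on the fine-tuned model gives a number to compare against the pickled size above.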

Instead, we’ll also wrap the model itself in dask.delayed. This means the model only shows up once in the Dask graph.

[16]:
dmodel = dask.delayed(model)

Now we’ll use the (delayed) predict function to get our predictions.

[17]:
predictions = [predict(batch, dmodel) for batch in batches]
dask.visualize(predictions[:2])
[17]:
[Image: task graph produced by dask.visualize(predictions[:2])]

The visualization is a bit messy, but the large PyTorch model is the box that’s an ancestor of both predict tasks.

Now, we can do the computation, using the Dask cluster to do all the work. Because the dataset we’re working with is small, it’s safe to just use dask.compute to bring the results back to the local Client. For a larger dataset you would want to write to disk or cloud storage or continue processing the predictions on the cluster.

[18]:
predictions = dask.compute(*predictions)
predictions
[18]:
(array([0, 0, 0, 0, 0, 0, 1, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
 array([0, 1, 0, 0, 0, 0, 0, 0, 0, 0]),
 array([0, 0, 1, 0, 0, 0, 0, 1, 0, 0]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([1, 1, 1, 1, 1, 0, 1, 1, 1, 1]),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
 array([1, 1, 1, 1, 1, 1, 1, 0, 1, 1]),
 array([1, 1, 1, 1, 1, 1, 0, 1, 1, 1]),
 array([1, 0, 1]))
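
For a larger dataset, one option is to keep the predictions on the cluster as a single Dask array instead of computing them all at once. The sketch below shows the idea with `da.from_delayed` and `da.concatenate`. `fake_predict` is a hypothetical stand-in for the delayed `predict` call, and the sketch assumes each batch's length is known up front, because `from_delayed` needs the shape.

```python
import dask
import dask.array as da
import numpy as np


@dask.delayed
def fake_predict(n, label):
    # Stand-in for predict(batch, model): returns n integer class labels.
    return np.full(n, label, dtype="int64")


# (batch size, label) pairs; the last batch is smaller, as with partition_all.
specs = [(10, 0), (10, 1), (3, 1)]

# Wrap each delayed result as a lazy array chunk with a known shape and dtype.
parts = [
    da.from_delayed(fake_predict(n, label), shape=(n,), dtype="int64")
    for n, label in specs
]

# One lazy array over all predictions; nothing has been computed yet.
all_predictions = da.concatenate(parts)

result = all_predictions.compute()
print(all_predictions.shape)  # (23,)
```

From here the array can be reduced or written out with the usual Dask array methods rather than being pulled back to the client in full.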

Summary

This example showed how to do batch prediction on a set of images using PyTorch and Dask. We were careful to load data remotely on the cluster, and to serialize the large neural network only once.