Batch Prediction with PyTorch

[ ]:
%matplotlib inline

This example follows PyTorch’s transfer learning tutorial. We will:

  1. Finetune a pretrained convolutional neural network on a specific task (ants vs. bees).

  2. Use a Dask cluster for batch prediction with that model.

Note: The dependencies for this example are not installed by default in the Binder environment. You’ll need to execute

!conda install torchvision pytorch-cpu

in a cell to install the necessary packages.

The primary focus is using a Dask cluster for batch prediction.

Download the data

The PyTorch documentation hosts a small set of data. We’ll download and extract it locally.

[ ]:
import urllib.request
import zipfile
[ ]:
filename, _ = urllib.request.urlretrieve("", "")

The directory looks like



Following the tutorial, we’ll finetune the model.

[ ]:
import torchvision
from tutorial_helper import (imshow, train_model, visualize_model,
                             dataloaders, class_names, finetune_model)

Finetune the model

Our base model is resnet18. It predicts 1,000 categories, while ours predicts just 2 (ants or bees). To keep training fast, we’ll only use a couple of epochs.

[ ]:
import dask
[ ]:
model = finetune_model()

Things seem OK on a few random images:

[ ]:

Batch Prediction with Dask

Now for the main topic: using a pretrained model for batch prediction on a Dask cluster. There are two main complications, and both come down to minimizing the amount of data moved around:

  1. Loading the data on the workers. We’ll use dask.delayed to load the data on the workers, rather than loading it on the client and sending it to the workers.

  2. PyTorch neural networks are large. We don’t want them in Dask task graphs, and we only want to move them around once.

[ ]:
from distributed import Client

client = Client(n_workers=2, threads_per_worker=2)

Loading the data on the workers

First, we’ll define a couple of helpers to load the data and preprocess it for the neural network. We’ll use dask.delayed here so that the execution is lazy and happens on the cluster. See the delayed example for more on using dask.delayed.

[ ]:
import glob
import toolz
import dask
import dask.array as da
import torch
from torchvision import transforms
from PIL import Image

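@dask.delayed
def load(path, fs=__builtins__):
    # `fs` is any object with an `open` method; the default uses the built-in open
    with fs.open(path, 'rb') as f:
        img = Image.open(f).convert("RGB")
        return img

@dask.delayed
def transform(img):
    # Resize and crop so every tensor has the same shape, then normalize
    trn = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    return trn(img)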
[ ]:
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]

To load the data from cloud storage, say Amazon S3, you would use

import s3fs

fs = s3fs.S3FileSystem(...)
objs = [load(x, fs=fs) for x in fs.glob(...)]

The PyTorch model expects tensors of a specific shape, so let’s transform them.

[ ]:
tensors = [transform(x) for x in objs]

And the model expects batches of inputs, so let’s stack a few together.

[ ]:
batches = [dask.delayed(torch.stack)(batch)
           for batch in toolz.partition_all(10, tensors)]
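
To make the batching step concrete, here is a minimal standard-library sketch of the grouping that toolz.partition_all performs: consecutive groups of at most n items, with a shorter final group. The name partition_all_sketch and the stand-in input range(23) are hypothetical and only for illustration; the notebook itself uses toolz.

```python
from itertools import islice


def partition_all_sketch(n, seq):
    """Yield tuples of up to n items; the last tuple may be shorter."""
    it = iter(seq)
    while True:
        chunk = tuple(islice(it, n))
        if not chunk:
            return
        yield chunk


# 23 stand-in items grouped into batches of 10
sizes = [len(chunk) for chunk in partition_all_sketch(10, range(23))]
print(sizes)  # [10, 10, 3]
```

Because the final batch can be smaller than the rest, any code that consumes the batches should not assume a fixed batch size.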

Finally, we’ll write a small predict helper to predict the output class (0 or 1).

[ ]:
@dask.delayed
def predict(batch, model):
    with torch.no_grad():
        out = model(batch)
        _, predicted = torch.max(out, 1)
        predicted = predicted.numpy()
    return predicted
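
The call torch.max(out, 1) returns both the largest value in each row and its index; predict keeps only the index, which is the predicted class. The plain-Python sketch below mirrors that per-row argmax on made-up scores. The helper argmax_rows and the numbers are hypothetical, chosen only to illustrate the idea without needing PyTorch.

```python
# Hypothetical model outputs: one row per image, one score per class (0 or 1).
scores = [[2.1, -0.3], [0.2, 1.7], [-1.0, -0.5]]


def argmax_rows(rows):
    """Return, for each row, the index of its largest value."""
    return [max(range(len(row)), key=row.__getitem__) for row in rows]


predicted = argmax_rows(scores)
print(predicted)  # [0, 1, 1]
```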

Moving the model around

PyTorch neural networks are large, so we don’t want to repeat them many times in our task graph (once per batch).

[ ]:
import pickle
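
As a rough way to see why repetition is costly, the pickled size of an object approximates how many bytes would be embedded each time it appears in a graph. The sketch below uses a stand-in list of floats rather than the real model, and the batch count is a hypothetical number, so the figures it prints are illustrative only.

```python
import pickle

# Stand-in for a large model (hypothetical): one million float "weights".
stand_in_weights = [0.0] * 1_000_000

# Bytes needed to serialize the object once.
n_bytes = len(pickle.dumps(stand_in_weights))
print(f"pickled once: {n_bytes / 1e6:.1f} MB")

# If the object were embedded once per batch, the cost would scale linearly.
n_batches = 16  # hypothetical number of batches
print(f"embedded per batch: {n_bytes * n_batches / 1e6:.1f} MB")
```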


Instead, we’ll also wrap the model itself in dask.delayed. This means the model only shows up once in the Dask graph.

Additionally, because the fine-tuning above runs on a GPU if one is available, we should move the model back to the CPU.

[ ]:
dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU

Now we’ll use the (delayed) predict function to get our predictions.

[ ]:
predictions = [predict(batch, dmodel) for batch in batches]

The visualization is a bit messy, but the large PyTorch model is the box that’s an ancestor of both predict tasks.

Now, we can do the computation, using the Dask cluster to do all the work. Because the dataset we’re working with is small, it’s safe to just use dask.compute to bring the results back to the local Client. For a larger dataset you would want to write to disk or cloud storage or continue processing the predictions on the cluster.

[ ]:
predictions = dask.compute(*predictions)
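
dask.compute returns a tuple with one result per input, so here there is one group of predicted labels per batch. A minimal sketch of flattening such a tuple into a single list follows. The values in per_batch are made up, and plain lists stand in for the NumPy arrays that predict returns; with real arrays, numpy.concatenate would serve the same purpose.

```python
from itertools import chain

# Stand-in for the tuple returned by dask.compute (hypothetical values):
# one sequence of predicted class labels per batch.
per_batch = ([0, 1, 1], [1, 0, 0], [1])

# Flatten the per-batch results into one list, preserving order.
all_predictions = list(chain.from_iterable(per_batch))
print(all_predictions)  # [0, 1, 1, 1, 0, 0, 1]
```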


This example showed how to do batch prediction on a set of images using PyTorch and Dask. We were careful to load data remotely on the cluster, and to serialize the large neural network only once.