Batch Prediction with PyTorch
Batch Prediction with PyTorch¶
This example follows Torch’s transfer learning tutorial. We will
Finetune a pretrained convolutional neural network on a specific task (ants vs. bees).
Use a Dask cluster for batch prediction with that model.
Note: The dependencies for this example are not installed by default in the Binder environment. You’ll need to execute
!conda install torchvision pytorch-cpu
in a cell to install the necessary packages.
The primary focus is using a Dask cluster for batch prediction.
Download the data¶
The PyTorch documentation hosts a small set of data. We’ll download and extract it locally.
import urllib.request import zipfile
filename, _ = urllib.request.urlretrieve("https://download.pytorch.org/tutorial/hymenoptera_data.zip", "data.zip") zipfile.ZipFile(filename).extractall()
The directory looks like
hymenoptera_data/ train/ ants/ 0013035.jpg ... 1030023514_aad5c608f9.jpg bees/ 1092977343_cb42b38d62.jpg ... 2486729079_62df0920be.jpg train/ ants/ 0013025.jpg ... 1030023514_aad5c606d9.jpg bees/ 1092977343_cb42b38e62.jpg ... 2486729079_62df0921be.jpg
Following the tutorial, we’ll finetune the model.
import torchvision from tutorial_helper import (imshow, train_model, visualize_model, dataloaders, class_names, finetune_model)
Finetune the model¶
Our base model is resnet18. It predicts for 1,000 categories, while ours just predicts 2 (ants or bees). To make this model train quickly on examples.dask.org, we’ll only use a couple of epochs.
%%time model = finetune_model()
Things seem OK on a few random images:
Batch Prediction with Dask¶
Now for the main topic: using a pretrained model for batch prediction on a Dask cluster. There are two main complications, that both deal with minimizing the amount of data moved around:
Loading the data on the workers.. We’ll use
dask.delayedto load the data on the workers, rather than loading it on the client and sending it to the workers.
PyTorch neural networks are large. We don’t want them in Dask task graphs, and we only want to move them around once.
from distributed import Client client = Client(n_workers=2, threads_per_worker=2) client
Loading the data on the workers¶
First, we’ll define a couple helpers to load the data and preprocess it for the neural network. We’ll use
dask.delayed here so that the execuation is lazy and happens on the cluster. See the delayed example for more on using
import glob import toolz import dask import dask.array as da import torch from torchvision import transforms from PIL import Image @dask.delayed def load(path, fs=__builtins__): with fs.open(path, 'rb') as f: img = Image.open(f).convert("RGB") return img @dask.delayed def transform(img): trn = transforms.Compose([ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) return trn(img)
objs = [load(x) for x in glob.glob("hymenoptera_data/val/*/*.jpg")]
To load the data from cloud storage, say Amazon S3, you would use
import s3fs fs = s3fs.S3FileSystem(...) objs = [load(x, fs=fs) for x in fs.glob(...)]
The PyTorch model expects tensors of a specific shape, so let’s transform them.
tensors = [transform(x) for x in objs]
And the model expects batches of inputs, so let’s stack a few together.
batches = [dask.delayed(torch.stack)(batch) for batch in toolz.partition_all(10, tensors)] batches[:5]
Finally, we’ll write a small
predict helper to predict the output class (0 or 1).
@dask.delayed def predict(batch, model): with torch.no_grad(): out = model(batch) _, predicted = torch.max(out, 1) predicted = predicted.numpy() return predicted
Moving the model around¶
PyTorch neural networks are large, so we don’t want to repeat it many times in our task graph (once per batch).
import pickle dask.utils.format_bytes(len(pickle.dumps(model)))
Instead, we’ll also wrap the model itself in
dask.delayed. This means the model only shows up once in the Dask graph.
Additionally, since we performed fine-tuning in the above (and that runs on a GPU if its available), we should move the model back to the CPU.
dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU
Now we’ll use the (delayed)
predict method to get our predictions.
predictions = [predict(batch, dmodel) for batch in batches] dask.visualize(predictions[:2])
The visualization is a bit messy, but the large PyTorch model is the box that’s an ancestor of both
Now, we can do the computation, using the Dask cluster to do all the work. Because the dataset we’re working with is small, it’s safe to just use
dask.compute to bring the results back to the local Client. For a larger dataset you would want to write to disk or cloud storage or continue processing the predictions on the cluster.
predictions = dask.compute(*predictions) predictions
This example showed how to do batch prediction on a set of images using PyTorch and Dask. We were careful to load data remotely on the cluster, and to serialize the large neural network only once.