{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Batch Prediction with PyTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example follows PyTorch's [transfer learning tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html). We will\n", "\n", "1. Finetune a pretrained convolutional neural network on a specific task (ants vs. bees).\n", "2. Use a Dask cluster for batch prediction with that model.\n", "\n", "_Note:_ The dependencies for this example are not installed by default in the Binder environment. You'll need to execute\n", "\n", "```\n", "!conda install torchvision pytorch-cpu\n", "```\n", "\n", "in a cell to install the necessary packages.\n", "\n", "The primary focus is using a Dask cluster for batch prediction." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Download the data\n", "\n", "The PyTorch documentation hosts a small set of data. We'll download and extract it locally."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import urllib.request\n", "import zipfile" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "filename, _ = urllib.request.urlretrieve(\"https://download.pytorch.org/tutorial/hymenoptera_data.zip\", \"data.zip\")\n", "zipfile.ZipFile(filename).extractall()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The extracted directory looks like this:\n", "\n", "```\n", "hymenoptera_data/\n", "    train/\n", "        ants/\n", "            0013035.jpg\n", "            ...\n", "            1030023514_aad5c608f9.jpg\n", "        bees/\n", "            1092977343_cb42b38d62.jpg\n", "            ...\n", "            2486729079_62df0920be.jpg\n", "\n", "    val/\n", "        ants/\n", "            0013025.jpg\n", "            ...\n", "            1030023514_aad5c606d9.jpg\n", "        bees/\n", "            1092977343_cb42b38e62.jpg\n", "            ...\n", "            2486729079_62df0921be.jpg\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Following the [tutorial](https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html), we'll finetune the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torchvision\n", "from tutorial_helper import (imshow, train_model, visualize_model,\n", "                             dataloaders, class_names, finetune_model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Finetune the model\n", "\n", "Our base model is resnet18. It predicts 1,000 categories, while ours predicts just 2 (ants or bees). To make the model train quickly on examples.dask.org, we'll use only a couple of epochs."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dask" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "model = finetune_model()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Things seem OK on a few random images:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "visualize_model(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch Prediction with Dask\n", "\n", "Now for the main topic: using a pretrained model for batch prediction on a Dask cluster.\n", "There are two main complications, both of which come down to minimizing the amount of data\n", "moved around:\n", "\n", "1. **Loading the data on the workers.** We'll use `dask.delayed` to load the data on\n", "   the workers, rather than loading it on the client and sending it to the workers.\n", "2. **PyTorch neural networks are large.** We don't want them in Dask task graphs, and we\n", "   only want to move them around once." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from distributed import Client\n", "\n", "client = Client(n_workers=2, threads_per_worker=2)\n", "client" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loading the data on the workers\n", "\n", "First, we'll define a couple of helpers to load the data and preprocess it for the neural network.\n", "We'll use `dask.delayed` here so that the execution is lazy and happens on the cluster.\n", "See [the delayed example](../delayed.ipynb) for more on using `dask.delayed`."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import builtins\n", "import glob\n", "import toolz\n", "import dask\n", "import dask.array as da\n", "import torch\n", "from torchvision import transforms\n", "from PIL import Image\n", "\n", "\n", "@dask.delayed\n", "def load(path, fs=builtins):\n", "    # `fs` is any object with an `open` method: the `builtins` module for\n", "    # local files, or a filesystem object such as `s3fs.S3FileSystem`.\n", "    with fs.open(path, 'rb') as f:\n", "        # `convert` reads the pixel data, so the file can be closed afterwards.\n", "        img = Image.open(f).convert(\"RGB\")\n", "    return img\n", "\n", "\n", "@dask.delayed\n", "def transform(img):\n", "    # Resize, crop, convert to a tensor, and normalize each channel.\n", "    trn = transforms.Compose([\n", "        transforms.Resize(256),\n", "        transforms.CenterCrop(224),\n", "        transforms.ToTensor(),\n", "        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n", "    ])\n", "    return trn(img)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "objs = [load(x) for x in glob.glob(\"hymenoptera_data/val/*/*.jpg\")]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load the data from cloud storage, say Amazon S3, you would use\n", "\n", "```python\n", "import s3fs\n", "\n", "fs = s3fs.S3FileSystem(...)\n", "objs = [load(x, fs=fs) for x in fs.glob(...)]\n", "```\n", "\n", "The PyTorch model expects tensors of a specific shape, so let's\n", "transform them." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tensors = [transform(x) for x in objs]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And the model expects batches of inputs, so let's stack a few together." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "batches = [dask.delayed(torch.stack)(batch)\n", "           for batch in toolz.partition_all(10, tensors)]\n", "batches[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we'll write a small `predict` helper to predict the output class (0 or 1)."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@dask.delayed\n", "def predict(batch, model):\n", "    # Inference only, so skip gradient tracking.\n", "    with torch.no_grad():\n", "        out = model(batch)\n", "        # The predicted class is the index of the largest output per row.\n", "        _, predicted = torch.max(out, 1)\n", "        predicted = predicted.numpy()\n", "    return predicted" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Moving the model around\n", "\n", "PyTorch neural networks are large, so we don't want to repeat them many times in our task graph (once per batch)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pickle\n", "\n", "dask.utils.format_bytes(len(pickle.dumps(model)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead, we'll wrap the model itself in `dask.delayed`, so the model shows up only once in the Dask graph.\n", "\n", "Additionally, since the fine-tuning above runs on a GPU if one is available, we should move the model back to the CPU." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dmodel = dask.delayed(model.cpu()) # ensuring model is on the CPU" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we'll use the (delayed) `predict` function to get our predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictions = [predict(batch, dmodel) for batch in batches]\n", "dask.visualize(predictions[:2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The visualization is a bit messy, but the large PyTorch model is the box that's an ancestor of both `predict` tasks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can do the computation, using the Dask cluster to do all the work. Because the dataset we're working with is small, it's safe to just use `dask.compute` to bring the results back to the local Client. 
For a larger dataset, you would want to write the predictions to disk or cloud storage, or continue processing them on the cluster." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictions = dask.compute(*predictions)\n", "predictions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary\n", "\n", "This example showed how to do batch prediction on a set of images using PyTorch and Dask.\n", "We were careful to load data remotely on the cluster, and to serialize the large neural network only once." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" }, "nbsphinx": { "execute": "never" } }, "nbformat": 4, "nbformat_minor": 4 }