{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Scale Scikit-Learn for Small Data Problems\n", "==========================================\n", "\n", "This example demonstrates how Dask can scale scikit-learn to a cluster of machines for a CPU-bound problem.\n", "We'll fit a large model (a grid search over many hyper-parameters) on a small dataset.\n", "\n", "This video demonstrates the same example on a larger cluster." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:40.954216Z", "iopub.status.busy": "2022-07-27T19:22:40.953750Z", "iopub.status.idle": "2022-07-27T19:22:41.018303Z", "shell.execute_reply": "2022-07-27T19:22:41.017611Z" } }, "outputs": [ { "data": { "image/jpeg": "\n", "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import YouTubeVideo\n", "\n", "YouTubeVideo(\"5Zf6DQaf7jk\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:41.021809Z", "iopub.status.busy": "2022-07-27T19:22:41.021259Z", "iopub.status.idle": "2022-07-27T19:22:44.310730Z", "shell.execute_reply": "2022-07-27T19:22:44.310048Z" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask.distributed import Client, progress\n", "client = Client(n_workers=4, threads_per_worker=1, memory_limit='2GB')\n", "client" ] }, { "cell_type": "markdown", "metadata": { "keep_output": true }, "source": [ "## Distributed Training\n", "\n", "Scikit-learn uses [joblib](http://joblib.readthedocs.io/) for single-machine parallelism. This lets you train most estimators (anything that accepts an `n_jobs` parameter) using all the cores of your laptop or workstation.\n", "\n", "Alternatively, scikit-learn can use Dask for parallelism. This lets you train those estimators using all the cores of your *cluster* without significantly changing your code.\n", "\n", "This is most useful for training large models on medium-sized datasets. You may have a large model when searching over many hyper-parameters, or when using an ensemble method with many individual estimators. For datasets that are too small, training times are typically short enough that cluster-wide parallelism doesn't help. For datasets that are too large (larger than a single machine's memory), the scikit-learn estimators may not be able to cope, though Dask-ML provides other ways of working with larger-than-memory datasets." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create Scikit-Learn Pipeline" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:44.314191Z", "iopub.status.busy": "2022-07-27T19:22:44.313670Z", "iopub.status.idle": "2022-07-27T19:22:44.773748Z", "shell.execute_reply": "2022-07-27T19:22:44.773086Z" } }, "outputs": [], "source": [ "from pprint import pprint\n", "from time import time\n", "import logging\n", "\n", "from sklearn.datasets import fetch_20newsgroups\n", "from sklearn.feature_extraction.text import HashingVectorizer\n", "from sklearn.feature_extraction.text import TfidfTransformer\n", "from sklearn.linear_model import SGDClassifier\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.pipeline import Pipeline" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:44.777601Z", "iopub.status.busy": "2022-07-27T19:22:44.776949Z", "iopub.status.idle": "2022-07-27T19:22:53.129035Z", "shell.execute_reply": "2022-07-27T19:22:53.128390Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading 20 newsgroups dataset for categories:\n", "['alt.atheism', 'talk.religion.misc']\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "857 documents\n", "2 categories\n", "\n" ] } ], "source": [ "# Scale Up: set categories=None to use all the categories\n", "categories = [\n", " 'alt.atheism',\n", " 'talk.religion.misc',\n", "]\n", "\n", "print(\"Loading 20 newsgroups dataset for categories:\")\n", "print(categories)\n", "\n", "data = fetch_20newsgroups(subset='train', categories=categories)\n", "print(\"%d documents\" % len(data.filenames))\n", "print(\"%d categories\" % len(data.target_names))\n", "print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll define a small pipeline that combines text feature extraction with a simple classifier." 
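, "\n", "\n", "To make the two feature-extraction steps concrete, here is a toy sketch, assuming only scikit-learn is installed. The three made-up documents and the settings `n_features=2**8`, `norm=None` and `alternate_sign=False` are illustrative choices that keep the intermediate counts small and non-negative; the pipeline in this notebook uses the library defaults instead.

```python
from sklearn.feature_extraction.text import HashingVectorizer, TfidfTransformer

# Three made-up documents standing in for newsgroup posts.
docs = [
    'god is not dead',
    'religion is a personal matter',
    'is there a god',
]

# Step 1: hash each token into one of n_features columns. The transform is
# stateless (no vocabulary is stored); norm=None and alternate_sign=False keep
# plain, non-negative token counts for this illustration.
hasher = HashingVectorizer(n_features=2**8, norm=None, alternate_sign=False)
counts = hasher.fit_transform(docs)

# Step 2: reweight the counts by inverse document frequency (learned in fit)
# and scale each row to unit L2 norm.
tfidf = TfidfTransformer(use_idf=True, norm='l2')
features = tfidf.fit_transform(counts)

# One row per document, one column per hash bucket.
print(features.shape)
```

`use_idf` and `norm` are the same `TfidfTransformer` parameters that the grid search varies through `tfidf__use_idf` and `tfidf__norm`."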
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:53.132634Z", "iopub.status.busy": "2022-07-27T19:22:53.132192Z", "iopub.status.idle": "2022-07-27T19:22:53.135995Z", "shell.execute_reply": "2022-07-27T19:22:53.135342Z" } }, "outputs": [], "source": [ "pipeline = Pipeline([\n", " ('vect', HashingVectorizer()),\n", " ('tfidf', TfidfTransformer()),\n", " ('clf', SGDClassifier(max_iter=1000)),\n", "])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Grid for Parameter Search" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Grid search over a few parameters. Each key names a pipeline step and one of that step's parameters, joined by a double underscore: `tfidf__norm`, for example, sets `norm` on the `tfidf` step." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:53.140031Z", "iopub.status.busy": "2022-07-27T19:22:53.139614Z", "iopub.status.idle": "2022-07-27T19:22:53.143379Z", "shell.execute_reply": "2022-07-27T19:22:53.142908Z" } }, "outputs": [], "source": [ "parameters = {\n", " 'tfidf__use_idf': (True, False),\n", " 'tfidf__norm': ('l1', 'l2'),\n", " 'clf__alpha': (0.00001, 0.000001),\n", " # 'clf__penalty': ('l2', 'elasticnet'),\n", " # 'clf__max_iter': (10, 50, 80),\n", "}" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:53.147810Z", "iopub.status.busy": "2022-07-27T19:22:53.145712Z", "iopub.status.idle": "2022-07-27T19:22:53.151018Z", "shell.execute_reply": "2022-07-27T19:22:53.150392Z" } }, "outputs": [], "source": [ "grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1, cv=3, refit=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To fit this normally, we would write\n", "\n", "```python\n", "grid_search.fit(data.data, data.target)\n", "```\n", "\n", "That would use the default joblib backend (multiple processes) for parallelism.\n", "To use the Dask distributed backend instead, which trains the model on a cluster of machines, perform the fit inside a `parallel_backend` context." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:53.154020Z", "iopub.status.busy": "2022-07-27T19:22:53.153618Z", "iopub.status.idle": "2022-07-27T19:23:00.999089Z", "shell.execute_reply": "2022-07-27T19:23:00.998395Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Fitting 3 folds for each of 8 candidates, totalling 24 fits\n" ] } ], "source": [ "import joblib\n", "\n", "with joblib.parallel_backend('dask'):\n", "    grid_search.fit(data.data, data.target)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you had the Dask dashboard open during that fit, you would have seen each worker perform some of the fit tasks." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Parallel, Distributed Prediction\n", "\n", "Sometimes you train on a small dataset but need to predict on a much larger batch of data.\n", "In this case, you'd like your estimator to handle NumPy arrays and pandas DataFrames for training, and Dask arrays or DataFrames for prediction. [`dask_ml.wrappers.ParallelPostFit`](http://ml.dask.org/modules/generated/dask_ml.wrappers.ParallelPostFit.html#dask_ml.wrappers.ParallelPostFit) provides exactly that. It's a meta-estimator that does nothing special during training: the underlying estimator (typically a scikit-learn estimator) is usually fit in memory on a single machine. Methods like `predict`, `score`, etc. are then parallelized and distributed.\n", "\n", "Most of the time, using `ParallelPostFit` is as simple as wrapping the original estimator.\n", "The only complication comes when using `ParallelPostFit` inside another meta-estimator like `GridSearchCV`. As with any nested estimator, you'll need to prefix the wrapped estimator's parameter names with `estimator__`."
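, "\n", "\n", "A rough sketch of the train-small, predict-big idea using plain Dask, assuming `dask[array]` and scikit-learn are installed: the estimator is fit in memory, then its `predict` method is applied independently to each row-block of a Dask array with `map_blocks`. This only illustrates the concept and is not necessarily how `ParallelPostFit` is implemented; the names `X_large` and `labels_lazy` and the three-fold replication are arbitrary choices for the example.

```python
import dask.array as da
import numpy as np
from sklearn.datasets import load_digits
from sklearn.svm import SVC

# Train on small, in-memory NumPy arrays.
X, y = load_digits(return_X_y=True)
clf = SVC(gamma='scale', random_state=0).fit(X, y)

# Stand-in for a larger dataset: three copies of X stacked row-wise,
# stored as three chunks of shape (1797, 64).
X_large = da.concatenate([da.from_array(X, chunks=X.shape)] * 3)

# Apply the fitted estimator to each row-block independently. Every block
# holds all 64 feature columns, so predict returns one label per row and the
# feature axis (axis 1) disappears from the output.
labels_lazy = X_large.map_blocks(
    clf.predict,
    drop_axis=1,
    dtype=clf.classes_.dtype,
    meta=np.empty((0,), dtype=clf.classes_.dtype),
)

# Nothing has been predicted yet; compute() runs one predict task per block.
labels = labels_lazy.compute()
print(labels.shape, labels_lazy.numblocks)
```

Because every block is predicted independently, the same task graph can run on local threads, as here, or on the workers of a distributed cluster once a `Client` is connected."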
] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:01.002484Z", "iopub.status.busy": "2022-07-27T19:23:01.002047Z", "iopub.status.idle": "2022-07-27T19:23:01.227161Z", "shell.execute_reply": "2022-07-27T19:23:01.226471Z" } }, "outputs": [], "source": [ "from sklearn.datasets import load_digits\n", "from sklearn.svm import SVC\n", "from dask_ml.wrappers import ParallelPostFit" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll load the small NumPy arrays for training." ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:01.230642Z", "iopub.status.busy": "2022-07-27T19:23:01.230198Z", "iopub.status.idle": "2022-07-27T19:23:01.289344Z", "shell.execute_reply": "2022-07-27T19:23:01.288604Z" } }, "outputs": [ { "data": { "text/plain": [ "(1797, 64)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "X, y = load_digits(return_X_y=True)\n", "X.shape" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:01.292868Z", "iopub.status.busy": "2022-07-27T19:23:01.292448Z", "iopub.status.idle": "2022-07-27T19:23:01.297885Z", "shell.execute_reply": "2022-07-27T19:23:01.297267Z" } }, "outputs": [], "source": [ "svc = ParallelPostFit(SVC(random_state=0, gamma='scale'))\n", "\n", "param_grid = {\n", " # use estimator__param instead of param\n", " 'estimator__C': [0.01, 1.0, 10],\n", "}\n", "\n", "grid_search = GridSearchCV(svc, param_grid, cv=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And fit as usual." 
] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:01.300843Z", "iopub.status.busy": "2022-07-27T19:23:01.300263Z", "iopub.status.idle": "2022-07-27T19:23:02.566998Z", "shell.execute_reply": "2022-07-27T19:23:02.565866Z" } }, "outputs": [ { "data": { "text/plain": [ "GridSearchCV(cv=3, estimator=ParallelPostFit(estimator=SVC(random_state=0)),\n", " param_grid={'estimator__C': [0.01, 1.0, 10]})" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grid_search.fit(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We'll simulate a large Dask array by replicating the training data ten times.\n", "In practice, you would load this data from your file system." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:02.570655Z", "iopub.status.busy": "2022-07-27T19:23:02.570436Z", "iopub.status.idle": "2022-07-27T19:23:02.575081Z", "shell.execute_reply": "2022-07-27T19:23:02.574223Z" } }, "outputs": [], "source": [ "import dask.array as da" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:02.578295Z", "iopub.status.busy": "2022-07-27T19:23:02.577878Z", "iopub.status.idle": "2022-07-27T19:23:02.608811Z", "shell.execute_reply": "2022-07-27T19:23:02.608081Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Array Chunk
Bytes 8.77 MiB 898.50 kiB
Shape (17970, 64) (1797, 64)
Count 11 Tasks 10 Chunks
Type float64 numpy.ndarray
\n", "
" ], "text/plain": [ "dask.array" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "big_X = da.concatenate([\n", " da.from_array(X, chunks=X.shape)\n", " for _ in range(10)\n", "])\n", "big_X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Operations like `predict` and `predict_proba` return Dask arrays rather than NumPy arrays.\n", "When you compute the result, the work is done in parallel, out of core, or distributed across the cluster." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:02.613227Z", "iopub.status.busy": "2022-07-27T19:23:02.612777Z", "iopub.status.idle": "2022-07-27T19:23:02.630320Z", "shell.execute_reply": "2022-07-27T19:23:02.629613Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Array Chunk
Bytes 140.39 kiB 14.04 kiB
Shape (17970,) (1797,)
Count 21 Tasks 10 Chunks
Type int64 numpy.ndarray
\n", "
" ], "text/plain": [ "dask.array<_predict, shape=(17970,), dtype=int64, chunksize=(1797,), chunktype=numpy.ndarray>" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predicted = grid_search.predict(big_X)\n", "predicted" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "At this point, `predicted` could be written to disk or aggregated before the results are returned to the client." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 4 }