{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Automate Machine Learning with TPOT\n", "===================================\n", "\n", "This example shows how [TPOT](https://epistasislab.github.io/tpot/) can be used with Dask.\n", "\n", "TPOT is an [automated machine learning](https://en.wikipedia.org/wiki/Automated_machine_learning) library.\n", "It evaluates many scikit-learn pipelines and hyperparameter combinations to find a model that works well for your data. Evaluating all of these pipelines is computationally expensive, but amenable to parallelism. TPOT can use Dask to distribute these computations across a cluster of machines.\n", "\n", "This notebook can be run interactively on the [dask examples binder](https://github.com/dask/dask-examples).\n", "The following video shows a larger version of this notebook running on a cluster." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:36.076555Z", "iopub.status.busy": "2022-07-27T19:23:36.076087Z", "iopub.status.idle": "2022-07-27T19:23:36.139259Z", "shell.execute_reply": "2022-07-27T19:23:36.138318Z" } }, "outputs": [ { "data": { "image/jpeg": "\n", "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from IPython.display import YouTubeVideo\n", "\n", "YouTubeVideo(\"uyx9nBuOYQQ\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:36.142851Z", "iopub.status.busy": "2022-07-27T19:23:36.142514Z", "iopub.status.idle": "2022-07-27T19:23:37.140189Z", "shell.execute_reply": "2022-07-27T19:23:37.138294Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/tpot/builtins/__init__.py:36: UserWarning: Warning: optional dependency `torch` is not available. 
- skipping import of NN models.\n", " warnings.warn(\"Warning: optional dependency `torch` is not available. - skipping import of NN models.\")\n" ] } ], "source": [ "import tpot\n", "from tpot import TPOTClassifier\n", "from sklearn.datasets import load_digits\n", "from sklearn.model_selection import train_test_split" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set Up Dask\n", "\n", "We first start a Dask client to get access to the Dask dashboard, which provides progress and performance metrics.\n", "\n", "After you run the cell, click the dashboard link to view the dashboard." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:37.144585Z", "iopub.status.busy": "2022-07-27T19:23:37.143842Z", "iopub.status.idle": "2022-07-27T19:23:40.054987Z", "shell.execute_reply": "2022-07-27T19:23:40.054281Z" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask.distributed import Client\n", "client = Client(n_workers=4, threads_per_worker=1)\n", "client" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Data\n", "\n", "We'll use the digits dataset.\n", "To ensure the example runs quickly, we'll make the training dataset relatively small." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:40.058341Z", "iopub.status.busy": "2022-07-27T19:23:40.057898Z", "iopub.status.idle": "2022-07-27T19:23:40.110810Z", "shell.execute_reply": "2022-07-27T19:23:40.110180Z" } }, "outputs": [], "source": [ "digits = load_digits()\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " digits.data,\n", " digits.target,\n", " train_size=0.05,\n", " test_size=0.95,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These are small, in-memory NumPy arrays. TPOT uses Dask to evaluate candidate pipelines in parallel, not to partition the data, so this example does not apply to larger-than-memory Dask arrays." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Dask\n", "\n", "TPOT follows the scikit-learn API: we create a `TPOTClassifier` with a few hyperparameters and then fit it to the training data.\n", "By default, TPOT trains on a single machine.\n", "To run the search on your Dask cluster, pass `use_dask=True`." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:40.114474Z", "iopub.status.busy": "2022-07-27T19:23:40.114266Z", "iopub.status.idle": "2022-07-27T19:23:40.118281Z", "shell.execute_reply": "2022-07-27T19:23:40.117512Z" } }, "outputs": [], "source": [ "# To scale up, increase TPOT parameters such as population_size and generations\n", "tp = TPOTClassifier(\n", " generations=2,\n", " population_size=10,\n", " cv=2,\n", " n_jobs=-1,\n", " random_state=0,\n", " verbosity=0,\n", " config_dict=tpot.config.classifier_config_dict_light,\n", " use_dask=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:23:40.121272Z", "iopub.status.busy": "2022-07-27T19:23:40.121061Z", "iopub.status.idle": "2022-07-27T19:23:44.357316Z", "shell.execute_reply": "2022-07-27T19:23:44.356815Z" } }, "outputs": [ { "data": { "text/plain": [ "TPOTClassifier(config_dict={'sklearn.cluster.FeatureAgglomeration': {'affinity': ['euclidean',\n", " 'l1',\n", " 'l2',\n", " 'manhattan',\n", " 'cosine'],\n", " 'linkage': ['ward',\n", " 'complete',\n", " 'average']},\n", " 'sklearn.decomposition.PCA': {'iterated_power': range(1, 11),\n", " 'svd_solver': ['randomized']},\n", " 'sklearn.feature_selection.SelectFwe': {'alpha': array([0. 
, 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007...\n", " 'max']},\n", " 'sklearn.preprocessing.RobustScaler': {},\n", " 'sklearn.preprocessing.StandardScaler': {},\n", " 'sklearn.tree.DecisionTreeClassifier': {'criterion': ['gini',\n", " 'entropy'],\n", " 'max_depth': range(1, 11),\n", " 'min_samples_leaf': range(1, 21),\n", " 'min_samples_split': range(2, 21)},\n", " 'tpot.builtins.ZeroCount': {}},\n", " cv=2, generations=2, n_jobs=-1, population_size=10,\n", " random_state=0, use_dask=True)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tp.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learn More\n", "\n", "See the [Dask-ML](http://ml.dask.org/) and [TPOT](https://epistasislab.github.io/tpot/) documentation for more information on using Dask and TPOT." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 4 }