
You can run this notebook in a live session on Binder or view it on GitHub.

Automate Machine Learning with TPOT

This example shows how TPOT can be used with Dask.

TPOT is an automated machine learning library. It evaluates many scikit-learn pipelines and hyperparameter combinations to find a model that works well for your data. Running all of these evaluations is computationally expensive, but the evaluations are independent of one another and therefore amenable to parallelism. TPOT can use Dask to distribute them across a cluster of machines.
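The parallelism comes from the fact that each candidate can be scored without waiting for any other. The sketch below illustrates only that idea, using the standard library's `concurrent.futures` as a stand-in for a Dask cluster. The `score` function and the hyperparameter grid are hypothetical toy examples; they are not TPOT internals and do not train a real model.

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product

# Hypothetical grid of (depth, learning_rate) pairs; a real search would
# cross-validate scikit-learn pipelines instead of calling a toy function.
GRID = list(product([1, 2, 3, 4, 5], [0.01, 0.1, 1.0]))


def score(params):
    """Toy stand-in for scoring one candidate (higher is better)."""
    depth, learning_rate = params
    return -((depth - 3) ** 2) - (learning_rate - 0.1) ** 2


def search(grid, max_workers=4):
    # Candidates are independent, so they can all be scored concurrently.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        scores = list(pool.map(score, grid))
    best_score, best_params = max(zip(scores, grid))
    return best_params, best_score


best_params, best_score = search(GRID)
print(best_params, best_score)  # (3, 0.1) 0.0
```

In this notebook, TPOT's `use_dask=True` option plays the analogous role of sending independent evaluations to Dask workers; the sketch does not reproduce how TPOT actually schedules that work.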

This notebook can be run interactively on the Dask examples Binder. The following video shows a larger version of this notebook running on a cluster.

from IPython.display import HTML

HTML('<div style="position:relative;height:0;padding-bottom:56.25%"><iframe src="" width="640" height="360" frameborder="0" allow="autoplay; encrypted-media" style="position:absolute;width:100%;height:100%;left:0" allowfullscreen></iframe></div>')
!pip install tpot
import tpot
from tpot import TPOTClassifier
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

Setup Dask

We first start a Dask client in order to get access to the Dask dashboard, which will provide progress and performance metrics.

You can view the dashboard by clicking on the dashboard link after you run the cell.

from dask.distributed import Client
client = Client(n_workers=4, threads_per_worker=1)



  • Workers: 4
  • Cores: 4
  • Memory: 8.36 GB

Create Data

We’ll use the digits dataset. To ensure the example runs quickly, we’ll make the training dataset relatively small.

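digits = load_digits()
# Use a small training split so the example runs quickly.
X_train, X_test, y_train, y_test = train_test_split(
    digits.data, digits.target, train_size=0.05, test_size=0.95)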

These are just small, in-memory NumPy arrays. This example is not applicable to larger-than-memory Dask arrays.
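To get a feel for how small these arrays are, the sketch below works through the arithmetic. It assumes the digits dataset's 1,797 samples of 64 float64 features and a training fraction of 0.05 (an illustrative choice), and it mirrors the rounding scikit-learn's `train_test_split` applies to fractional sizes as we understand it: the train size is rounded down and the test size rounded up.

```python
from math import ceil, floor

N_SAMPLES = 1797      # samples returned by sklearn's load_digits()
N_FEATURES = 64       # 8x8 pixel images, flattened
BYTES_PER_VALUE = 8   # float64


def split_sizes(n_samples, train_size, test_size):
    """Row counts for fractional sizes: train rounded down, test rounded up."""
    return floor(train_size * n_samples), ceil(test_size * n_samples)


n_train, n_test = split_sizes(N_SAMPLES, train_size=0.05, test_size=0.95)
total_bytes = N_SAMPLES * N_FEATURES * BYTES_PER_VALUE

print(n_train, n_test)  # 89 1708
print(total_bytes)      # 920064, i.e. under 1 MB for the whole feature matrix
```

At well under a megabyte, the data fits comfortably in each worker's memory; the expensive part is the number of model fits, not the data size.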

Using Dask

TPOT follows the scikit-learn API: we create a TPOTClassifier with a few hyperparameters and then fit it on some data. By default, TPOT runs on a single machine. To use the Dask cluster instead, pass use_dask=True.

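# scale up: Increase the TPOT parameters like population_size, generations
tp = TPOTClassifier(generations=2, population_size=10, cv=2, n_jobs=-1,
                    random_state=0, verbosity=0, use_dask=True,
                    config_dict=tpot.config.classifier_config_dict_light)
tp.fit(X_train, y_train)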
TPOTClassifier(config_dict={'sklearn.cluster.FeatureAgglomeration': {'affinity': ['euclidean',
                                                                     'linkage': ['ward',
                            'sklearn.decomposition.PCA': {'iterated_power': range(1, 11),
                                                          'svd_solver': ['randomized']},
                            'sklearn.feature_selection.SelectFwe': {'alpha': array([0.   , 0.001, 0.002, 0.003, 0.004, 0.005, 0.006, 0.007...
               crossover_rate=0.1, cv=2, disable_update_check=False,
               early_stop=None, generations=2,
               log_file=<ipykernel.iostream.OutStream object at 0x7fbe5e7e3e10>,
               max_eval_time_mins=5, max_time_mins=None, memory=None,
               mutation_rate=0.9, n_jobs=-1, offspring_size=None,
               periodic_checkpoint_folder=None, population_size=10,
               random_state=0, scoring=None, subsample=1.0, template=None,
               use_dask=True, verbosity=0, warm_start=False)
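The `# scale up` comment is easier to act on once the workload is quantified. The TPOT documentation describes a run as evaluating `population_size + generations * offspring_size` pipelines, with `offspring_size` defaulting to `population_size` when left as `None`, and each pipeline is cross-validated `cv` times. The sketch below applies that formula to the settings in the estimator shown above; the helper name `tpot_workload` is hypothetical, and the result is a rough estimate that ignores pipelines TPOT may skip or time out.

```python
def tpot_workload(population_size, generations, cv, offspring_size=None):
    """Rough count of pipelines evaluated and cross-validation fits."""
    # When offspring_size is None, TPOT produces population_size offspring
    # in each generation.
    if offspring_size is None:
        offspring_size = population_size
    pipelines = population_size + generations * offspring_size
    return pipelines, pipelines * cv


# Settings used in this notebook.
print(tpot_workload(population_size=10, generations=2, cv=2))    # (30, 60)
# A hypothetical scaled-up run.
print(tpot_workload(population_size=100, generations=10, cv=5))  # (1100, 5500)
```

Because the number of fits grows with both `generations` and the population or offspring size, these are the parameters to increase when more Dask workers are available.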

Learn More

See the Dask-ML and TPOT documentation for more information on using Dask and TPOT.