{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Hyperparameter optimization with Dask\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Every machine learning model has some values that are specified before training begins. These values help adapt the model to the data but must be given before any training data is seen. For example, this might be `penalty` or `C` in Scikit-learn's [LogisticRegression]. Values like these are called \"hyperparameters\". Typical usage looks something like:\n", "\n", "``` python\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.datasets import make_classification\n", "\n", "X, y = make_classification()\n", "est = LogisticRegression(C=10, penalty=\"l2\")\n", "est.fit(X, y)\n", "```\n", "\n", "These hyperparameters influence the quality of the prediction. For example, if `C` is too small in the example above, the estimator is over-regularized and will not fit the data well.\n", "\n", "Determining the values of these hyperparameters is difficult. In fact, Scikit-learn has an entire documentation page on finding the best values: https://scikit-learn.org/stable/modules/grid_search.html\n", "\n", "[LogisticRegression]:https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dask enables some new techniques and opportunities for hyperparameter optimization. One of these opportunities involves stopping training early to limit computation. Naturally, this requires some way to stop and restart training (`partial_fit` or `warm_start` in Scikit-learn parlance).\n", "\n", "This is especially useful when the search is complex and has many search parameters. 
Most deep learning models are good examples: they have specialized algorithms for handling large amounts of data, but it is difficult to choose their basic hyperparameters (e.g., \"learning rate\", \"momentum\" or \"weight decay\").\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**This notebook will walk through**\n", "\n", "* setting up a realistic example\n", "* how to use `HyperbandSearchCV`, including\n", " * understanding the input parameters to `HyperbandSearchCV`\n", " * running the hyperparameter optimization\n", " * how to access information from `HyperbandSearchCV`\n", " \n", "This notebook will specifically *not* show a performance comparison motivating `HyperbandSearchCV` use. `HyperbandSearchCV` finds high scores with minimal training; however, this is a tutorial on how to *use* it. All performance comparisons are relegated to the section [*Learn more*](#Learn-more)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:09.460955Z", "iopub.status.busy": "2021-01-14T10:49:09.460398Z", "iopub.status.idle": "2021-01-14T10:49:09.865307Z", "shell.execute_reply": "2021-01-14T10:49:09.865688Z" } }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Dask" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:09.868489Z", "iopub.status.busy": "2021-01-14T10:49:09.867745Z", "iopub.status.idle": "2021-01-14T10:49:11.123662Z", "shell.execute_reply": "2021-01-14T10:49:11.124129Z" } }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "
Client / Cluster: Workers: 1, Cores: 4, Memory: 2.00 GB
" ], "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from distributed import Client\n", "client = Client(processes=False, threads_per_worker=4,\n", " n_workers=1, memory_limit='2GB')\n", "client" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create Data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:11.128786Z", "iopub.status.busy": "2021-01-14T10:49:11.126507Z", "iopub.status.idle": "2021-01-14T10:49:12.335603Z", "shell.execute_reply": "2021-01-14T10:49:12.335138Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.datasets import make_circles\n", "import numpy as np\n", "import pandas as pd\n", "\n", "X, y = make_circles(n_samples=30_000, random_state=0, noise=0.09)\n", "\n", "pd.DataFrame({0: X[:, 0], 1: X[:, 1], \"class\": y}).sample(4_000).plot.scatter(\n", " x=0, y=1, alpha=0.2, c=\"class\", cmap=\"bwr\"\n", ");" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Add random dimensions" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.340434Z", "iopub.status.busy": "2021-01-14T10:49:12.340020Z", "iopub.status.idle": "2021-01-14T10:49:12.344320Z", "shell.execute_reply": "2021-01-14T10:49:12.344649Z" } }, "outputs": [ { "data": { "text/plain": [ "(30000, 6)" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.utils import check_random_state\n", "\n", "rng = check_random_state(42)\n", "random_feats = rng.uniform(-1, 1, size=(X.shape[0], 4))\n", "X = np.hstack((X, random_feats))\n", "X.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Split and scale data" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.349486Z", "iopub.status.busy": "2021-01-14T10:49:12.349077Z", "iopub.status.idle": "2021-01-14T10:49:12.366499Z", "shell.execute_reply": "2021-01-14T10:49:12.367447Z" } }, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=5_000, random_state=42)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.370869Z", "iopub.status.busy": "2021-01-14T10:49:12.370289Z", "iopub.status.idle": "2021-01-14T10:49:12.377627Z", "shell.execute_reply": 
"2021-01-14T10:49:12.377937Z" } }, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "from sklearn.model_selection import train_test_split\n", "scaler = StandardScaler().fit(X_train)\n", "\n", "X_train = scaler.transform(X_train)\n", "X_test = scaler.transform(X_test)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.384717Z", "iopub.status.busy": "2021-01-14T10:49:12.384146Z", "iopub.status.idle": "2021-01-14T10:49:12.386105Z", "shell.execute_reply": "2021-01-14T10:49:12.385126Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "dataset = train\n", "shape = (25000, 6)\n", "bytes = 1.20 MB\n", "--------------------\n", "dataset = test\n", "shape = (5000, 6)\n", "bytes = 240.00 kB\n", "--------------------\n" ] } ], "source": [ "from dask.utils import format_bytes\n", "\n", "for name, X in [(\"train\", X_train), (\"test\", X_test)]:\n", " print(\"dataset =\", name)\n", " print(\"shape =\", X.shape)\n", " print(\"bytes =\", format_bytes(X.nbytes))\n", " print(\"-\" * 20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have our train and test sets." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create model and search space" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's use Scikit-learn's MLPClassifier as our model (for convenience). 
We'll keep the total number of neurons fixed at 24 and tune how they are arranged into layers, along with some other basic hyperparameters.\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.389423Z", "iopub.status.busy": "2021-01-14T10:49:12.388991Z", "iopub.status.idle": "2021-01-14T10:49:12.395883Z", "shell.execute_reply": "2021-01-14T10:49:12.395351Z" } }, "outputs": [], "source": [ "import numpy as np\n", "from sklearn.neural_network import MLPClassifier\n", "\n", "model = MLPClassifier()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Deep learning libraries can be used as well. In particular, [PyTorch]'s Scikit-Learn wrapper [Skorch] works well with `HyperbandSearchCV`.\n", "\n", "[PyTorch]:https://pytorch.org/\n", "[Skorch]:https://skorch.readthedocs.io/en/stable/" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.404939Z", "iopub.status.busy": "2021-01-14T10:49:12.403829Z", "iopub.status.idle": "2021-01-14T10:49:12.405362Z", "shell.execute_reply": "2021-01-14T10:49:12.404178Z" } }, "outputs": [], "source": [ "params = {\n", " \"hidden_layer_sizes\": [\n", " (24, ),\n", " (12, 12),\n", " (6, 6, 6, 6),\n", " (4, 4, 4, 4, 4, 4),\n", " (12, 6, 3, 3),\n", " ],\n", " \"activation\": [\"relu\", \"logistic\", \"tanh\"],\n", " \"alpha\": np.logspace(-6, -3, num=1000), # 1000 log-spaced values (effectively continuous)\n", " \"batch_size\": [16, 32, 64, 128, 256, 512],\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Hyperparameter optimization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`HyperbandSearchCV` is Dask-ML's meta-estimator to find the best hyperparameters. It can be used as an alternative to `RandomizedSearchCV` to find similar hyperparameters in less time by not wasting time on hyperparameters that are not promising. 
Specifically, it is designed to find high-performing models with minimal training.\n", "\n", "This section will focus on\n", "\n", "1. Understanding the input parameters to `HyperbandSearchCV`\n", "2. Using `HyperbandSearchCV` to find the best hyperparameters\n", "3. Seeing other use cases of `HyperbandSearchCV`" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.413344Z", "iopub.status.busy": "2021-01-14T10:49:12.412862Z", "iopub.status.idle": "2021-01-14T10:49:12.679320Z", "shell.execute_reply": "2021-01-14T10:49:12.679764Z" } }, "outputs": [], "source": [ "from dask_ml.model_selection import HyperbandSearchCV" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Determining input parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A rule of thumb for choosing `HyperbandSearchCV`'s input parameters requires knowing:\n", "\n", "1. the number of examples the longest trained model will see\n", "2. the number of hyperparameter combinations to evaluate" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's write down what these should be for this example:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.688337Z", "iopub.status.busy": "2021-01-14T10:49:12.682037Z", "iopub.status.idle": "2021-01-14T10:49:12.691415Z", "shell.execute_reply": "2021-01-14T10:49:12.692184Z" } }, "outputs": [], "source": [ "# For quick response\n", "n_examples = 4 * len(X_train)\n", "n_params = 8\n", "\n", "# In practice, HyperbandSearchCV is most useful for longer searches\n", "# n_examples = 15 * len(X_train)\n", "# n_params = 15" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With these values, the models trained the longest will see `n_examples` examples. This is how much data is required, which is normally set by the problem's difficulty. 
Simple problems may only need 10 passes through the dataset; more complex problems may need 100 passes through the dataset.\n", "\n", "About `n_params` hyperparameter combinations will be sampled, so roughly that many models will be evaluated. Models with low scores will be terminated before they see `n_examples` examples. This helps preserve computation." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How can we use these values to determine the inputs for `HyperbandSearchCV`?" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.695618Z", "iopub.status.busy": "2021-01-14T10:49:12.694425Z", "iopub.status.idle": "2021-01-14T10:49:12.702942Z", "shell.execute_reply": "2021-01-14T10:49:12.703680Z" } }, "outputs": [ { "data": { "text/plain": [ "(8, 12500)" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "max_iter = n_params # maximum number of partial_fit calls for any one model\n", "chunks = n_examples // n_params # number of examples each call sees\n", "\n", "max_iter, chunks" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This means that the longest trained estimator will see about `n_examples` examples (specifically, `n_params * (n_examples // n_params)` examples; here, 8 * 12500 = 100000)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Applying input parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's create a Dask array with this chunk size:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.707285Z", "iopub.status.busy": "2021-01-14T10:49:12.706237Z", "iopub.status.idle": "2021-01-14T10:49:12.721967Z", "shell.execute_reply": "2021-01-14T10:49:12.722803Z" } }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "
Bytes: 1.20 MB (chunk: 600.00 kB)
Shape: (25000, 6) (chunk: (12500, 6))
Count: 3 tasks, 2 chunks
Type: float64 (chunk type: numpy.ndarray)
" ], "text/plain": [ "dask.array" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import dask.array as da\n", "X_train2 = da.from_array(X_train, chunks=chunks)\n", "y_train2 = da.from_array(y_train, chunks=chunks)\n", "X_train2" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each `partial_fit` call will receive one chunk." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That means the number of examples in each chunk should be (about) the same, and `n_examples` and `n_params` should be chosen to make that happen (e.g., with 100 examples, aim for chunks of `(33, 33, 34)` examples, not `(48, 48, 4)`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let's use `max_iter` to create our `HyperbandSearchCV` object:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.725457Z", "iopub.status.busy": "2021-01-14T10:49:12.724538Z", "iopub.status.idle": "2021-01-14T10:49:12.728570Z", "shell.execute_reply": "2021-01-14T10:49:12.729312Z" } }, "outputs": [], "source": [ "search = HyperbandSearchCV(\n", " model,\n", " params,\n", " max_iter=max_iter,\n", " patience=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How much computation will be performed?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It isn't clear how to determine how much computation is done from `max_iter` and `chunks`. 
Luckily, `HyperbandSearchCV` has a `metadata` attribute to determine this beforehand:" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.733058Z", "iopub.status.busy": "2021-01-14T10:49:12.731556Z", "iopub.status.idle": "2021-01-14T10:49:12.742564Z", "shell.execute_reply": "2021-01-14T10:49:12.743667Z" } }, "outputs": [ { "data": { "text/plain": [ "26" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search.metadata[\"partial_fit_calls\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This shows how many `partial_fit` calls will be performed in the computation. `metadata` also includes information on the number of models created." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "So far, all that's been done is getting the search ready for computation (and seeing how much computation will be performed), and all of it has been quick and easy." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Performing the computation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's do the model selection search and find the best hyperparameters. This is the real core of this notebook. 
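Before running it, it can help to see where the `partial_fit` count reported by `metadata` comes from. The sketch below is a hypothetical helper (`hyperband_budget` is our own name, not part of Dask-ML) that replays the bracket schedule from Algorithm 1 of the Hyperband paper by Li et al. It assumes `aggressiveness=3` (the `HyperbandSearchCV` default) and ignores the early stopping enabled by `patience`. Treat it as a rough estimate of the budget, not as Dask-ML's implementation. For `max_iter=8` it reproduces the 26 calls reported above.

```python
import math


def hyperband_budget(max_iter, eta=3):
    # Hypothetical helper (not part of Dask-ML): replays the bracket schedule
    # of Algorithm 1 in the Hyperband paper and returns
    # (number of models created, total partial_fit calls).
    # eta plays the role of aggressiveness; patience-based stopping is ignored.
    s_max = math.floor(math.log(max_iter, eta))
    budget = (s_max + 1) * max_iter
    total_models, total_calls = 0, 0
    for s in reversed(range(s_max + 1)):
        # Bracket s starts n models, each trained for r calls before the first cut.
        n = math.ceil(budget / max_iter * eta ** s / (s + 1))
        r = int(max_iter * eta ** -s)
        calls = [0] * n  # cumulative partial_fit calls per model
        survivors = list(range(n))
        for i in range(s + 1):
            n_i = math.floor(n * eta ** -i)
            r_i = round(r * eta ** i)
            for model in survivors:
                calls[model] = r_i
            # Keep the best n_i // eta models. Which ones survive depends on
            # their scores; for counting calls, any choice gives the same total.
            survivors = survivors[: n_i // eta]
        total_models += n
        total_calls += sum(calls)
    return total_models, total_calls


print(hyperband_budget(8))  # prints (5, 26)
```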
This computation will take place on all the hardware Dask has available.\n" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:12.747231Z", "iopub.status.busy": "2021-01-14T10:49:12.746059Z", "iopub.status.idle": "2021-01-14T10:49:17.219687Z", "shell.execute_reply": "2021-01-14T10:49:17.218223Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.57 s, sys: 156 ms, total: 2.72 s\n", "Wall time: 4.06 s\n" ] }, { "data": { "text/plain": [ "HyperbandSearchCV(estimator=MLPClassifier(), max_iter=8,\n", " parameters={'activation': ['relu', 'logistic', 'tanh'],\n", " 'alpha': array([1.00000000e-06, 1.00693863e-06, 1.01392541e-06, 1.02096066e-06,\n", " 1.02804473e-06, 1.03517796e-06, 1.04236067e-06, 1.04959323e-06,\n", " 1.05687597e-06, 1.06420924e-06, 1.07159340e-06, 1.07902879e-06,\n", " 1.08651577e-06, 1.09405471e-06, 1.10164595e-06, 1.1...\n", " 9.01477631e-04, 9.07732653e-04, 9.14031075e-04, 9.20373200e-04,\n", " 9.26759330e-04, 9.33189772e-04, 9.39664831e-04, 9.46184819e-04,\n", " 9.52750047e-04, 9.59360829e-04, 9.66017480e-04, 9.72720319e-04,\n", " 9.79469667e-04, 9.86265846e-04, 9.93109181e-04, 1.00000000e-03]),\n", " 'batch_size': [16, 32, 64, 128, 256, 512],\n", " 'hidden_layer_sizes': [(24,), (12, 12),\n", " (6, 6, 6, 6),\n", " (4, 4, 4, 4, 4, 4),\n", " (12, 6, 3, 3)]},\n", " patience=True)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%time\n", "search.fit(X_train2, y_train2, classes=[0, 1])  # make_circles only produces labels 0 and 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dashboard will be active while this is running. It will show which workers are running `partial_fit` and `score` calls.\n", "This takes only a few seconds here (about 4 seconds of wall time in the run above)." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Integration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`HyperbandSearchCV` follows the Scikit-learn API and mirrors Scikit-learn's `RandomizedSearchCV`. This means that it \"just works\". All the Scikit-learn attributes and methods are available:" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.224233Z", "iopub.status.busy": "2021-01-14T10:49:17.223823Z", "iopub.status.idle": "2021-01-14T10:49:17.233654Z", "shell.execute_reply": "2021-01-14T10:49:17.232774Z" } }, "outputs": [ { "data": { "text/plain": [ "0.5654" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search.best_score_" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.238179Z", "iopub.status.busy": "2021-01-14T10:49:17.237781Z", "iopub.status.idle": "2021-01-14T10:49:17.253727Z", "shell.execute_reply": "2021-01-14T10:49:17.254325Z" } }, "outputs": [ { "data": { "text/plain": [ "MLPClassifier(activation='tanh', alpha=2.010496416260497e-06, batch_size=32,\n", " hidden_layer_sizes=(4, 4, 4, 4, 4, 4))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search.best_estimator_" ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.269080Z", "iopub.status.busy": "2021-01-14T10:49:17.258800Z", "iopub.status.idle": "2021-01-14T10:49:17.337714Z", "shell.execute_reply": "2021-01-14T10:49:17.338037Z" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " param_alpha partial_fit_calls mean_score_time model_id \\\n", "0 0.000017 2 0.016706 bracket=1-0 \n", "1 0.000031 2 0.034787 bracket=1-1 \n", "2 0.000001 2 0.016971 bracket=1-2 \n", "3 0.000002 8 0.017569 bracket=0-0 \n", "4 0.000813 3 0.033582 bracket=0-1 \n", "\n", " param_hidden_layer_sizes param_activation std_score_time param_batch_size \\\n", "0 (6, 6, 6, 6) logistic 0.010332 128 \n", "1 (12, 12) logistic 0.000737 128 \n", "2 (4, 4, 4, 4, 4, 4) relu 0.000205 512 \n", "3 [4, 4, 4, 4, 4, 4] tanh 0.010754 32 \n", "4 [4, 4, 4, 4, 4, 4] tanh 0.017094 512 \n", "\n", " bracket params test_score \\\n", "0 1 {'hidden_layer_sizes': (6, 6, 6, 6), 'batch_si... 0.4894 \n", "1 1 {'hidden_layer_sizes': (12, 12), 'batch_size':... 0.5106 \n", "2 1 {'hidden_layer_sizes': (4, 4, 4, 4, 4, 4), 'ba... 0.5106 \n", "3 0 {'hidden_layer_sizes': (4, 4, 4, 4, 4, 4), 'ba... 0.5654 \n", "4 0 {'hidden_layer_sizes': (4, 4, 4, 4, 4, 4), 'ba... 0.0000 \n", "\n", " rank_test_score mean_partial_fit_time std_partial_fit_time \n", "0 3 0.338711 0.015633 \n", "1 1 0.250874 0.049119 \n", "2 1 0.153489 0.008385 \n", "3 1 0.497521 0.316641 \n", "4 2 0.129465 0.062709 " ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cv_results = pd.DataFrame(search.cv_results_)\n", "cv_results.head()" ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.341062Z", "iopub.status.busy": "2021-01-14T10:49:17.340118Z", "iopub.status.idle": "2021-01-14T10:49:17.388057Z", "shell.execute_reply": "2021-01-14T10:49:17.388387Z" } }, "outputs": [ { "data": { "text/plain": [ "0.5706" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search.score(X_test, y_test)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.391766Z", "iopub.status.busy": "2021-01-14T10:49:17.390740Z", 
"iopub.status.idle": "2021-01-14T10:49:17.434240Z", "shell.execute_reply": "2021-01-14T10:49:17.435062Z" } }, "outputs": [ { "data": { "text/html": [ "\n", "\n", "\n", "\n", "\n", "
" ], "text/plain": [ "dask.array<_predict, shape=(5000,), dtype=int64, chunksize=(5000,), chunktype=numpy.ndarray>" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.438410Z", "iopub.status.busy": "2021-01-14T10:49:17.437408Z", "iopub.status.idle": "2021-01-14T10:49:17.500373Z", "shell.execute_reply": "2021-01-14T10:49:17.499763Z" } }, "outputs": [ { "data": { "text/plain": [ "array([1, 1, 1, ..., 1, 0, 1])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "search.predict(X_test).compute()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It also has some other attributes." ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "execution": { "iopub.execute_input": "2021-01-14T10:49:17.513827Z", "iopub.status.busy": "2021-01-14T10:49:17.513383Z", "iopub.status.idle": "2021-01-14T10:49:17.546470Z", "shell.execute_reply": "2021-01-14T10:49:17.546787Z" } }, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ " model_id params \\\n", "0 bracket=1-0 {'hidden_layer_sizes': (6, 6, 6, 6), 'batch_si... \n", "1 bracket=1-1 {'hidden_layer_sizes': (12, 12), 'batch_size':... \n", "2 bracket=1-2 {'hidden_layer_sizes': (4, 4, 4, 4, 4, 4), 'ba... \n", "3 bracket=1-1 {'hidden_layer_sizes': (12, 12), 'batch_size':... \n", "4 bracket=1-2 {'hidden_layer_sizes': (4, 4, 4, 4, 4, 4), 'ba... \n", "\n", " partial_fit_calls partial_fit_time score score_time elapsed_wall_time \\\n", "0 1 0.354344 0.4894 0.027038 0.514815 \n", "1 1 0.299993 0.5106 0.034050 0.514816 \n", "2 1 0.161874 0.5106 0.016766 0.514816 \n", "3 2 0.201755 0.5106 0.035524 0.820638 \n", "4 2 0.145104 0.5106 0.017176 0.820639 \n", "\n", " bracket \n", "0 1 \n", "1 1 \n", "2 1 \n", "3 1 \n", "4 1 " ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "hist = pd.DataFrame(search.history_)\n", "hist.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This illustrates the history after every `partial_fit` call. There's also an attribute, `model_history_`, that records the history for each model (it's a reorganization of `history_`)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learn more" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This notebook covered basic usage of `HyperbandSearchCV`. The following documentation and resources might be useful to learn more about `HyperbandSearchCV`, including some of the finer use cases:\n", "\n", "* [A talk](https://www.youtube.com/watch?v=x67K9FiPFBQ) introducing `HyperbandSearchCV` to the SciPy 2019 audience and the [corresponding paper](https://conference.scipy.org/proceedings/scipy2019/pdfs/scott_sievert.pdf)\n", "* [HyperbandSearchCV's documentation](https://ml.dask.org/modules/generated/dask_ml.model_selection.HyperbandSearchCV.html)\n", "\n", "Performance comparisons can be found in the SciPy 2019 talk/paper." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 4 }