{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Scale XGBoost\n", "=============\n", "\n", "Dask and XGBoost can work together to train gradient-boosted trees in parallel. This notebook shows how to use Dask and XGBoost together.\n", "\n", "XGBoost provides a powerful prediction framework, and it works well in practice. It wins Kaggle contests and is popular in industry because it performs well and is easy to interpret (e.g., it's easy to find the most important features of an XGBoost model)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup Dask\n", "We set up a Dask client, which provides performance and progress metrics via the dashboard.\n", "\n", "You can view the dashboard by clicking the link after running the cell." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:01.681109Z", "iopub.status.busy": "2022-06-17T02:51:01.680651Z", "iopub.status.idle": "2022-06-17T02:51:05.005555Z", "shell.execute_reply": "2022-06-17T02:51:05.004594Z" } }, "outputs": [ { "data": { "text/html": [ "
Client: distributed.LocalCluster with 4 workers (1 thread and 1.70 GiB memory each, 6.78 GiB total). Dashboard: http://127.0.0.1:8787/status
" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask.distributed import Client\n", "\n", "# Start a local cluster of 4 single-threaded worker processes\n", "client = Client(n_workers=4, threads_per_worker=1)\n", "client" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we create some synthetic data: 100,000 examples with 20 features each." ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:05.008874Z", "iopub.status.busy": "2022-06-17T02:51:05.008599Z", "iopub.status.idle": "2022-06-17T02:51:05.822738Z", "shell.execute_reply": "2022-06-17T02:51:05.821707Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask/base.py:1283: UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.\n", " warnings.warn(\n" ] }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", "
| | Array | Chunk |
|---|---|---|
| Bytes | 15.26 MiB | 156.25 kiB |
| Shape | (100000, 20) | (1000, 20) |
| Count | 100 Tasks | 100 Chunks |
| Type | float64 | numpy.ndarray |
" ], "text/plain": [ "dask.array" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask_ml.datasets import make_classification\n", "\n", "# 100,000 rows x 20 features, stored as 100 chunks of 1,000 rows; 4 features are informative\n", "X, y = make_classification(n_samples=100000, n_features=20,\n", "                           chunks=1000, n_informative=4,\n", "                           random_state=0)\n", "X" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dask-XGBoost works with both arrays and dataframes. For more information on creating Dask arrays and dataframes from real data, see the documentation on [Dask arrays](https://dask.pydata.org/en/latest/array-creation.html) or [Dask dataframes](https://dask.pydata.org/en/latest/dataframe-create.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Split data for training and testing\n", "We split our dataset into training and testing sets so that the model is evaluated fairly, on data it did not see during training:" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:05.826227Z", "iopub.status.busy": "2022-06-17T02:51:05.825795Z", "iopub.status.idle": "2022-06-17T02:51:06.087747Z", "shell.execute_reply": "2022-06-17T02:51:06.086929Z" } }, "outputs": [], "source": [ "from dask_ml.model_selection import train_test_split\n", "\n", "# Hold out 15% of the rows as a test set\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's train a model on this data using [dask-xgboost][dxgb].\n", "\n", "[dxgb]:https://github.com/dask/dask-xgboost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train Dask-XGBoost" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:06.091911Z", "iopub.status.busy": "2022-06-17T02:51:06.091458Z", "iopub.status.idle": "2022-06-17T02:51:06.138424Z", "shell.execute_reply": "2022-06-17T02:51:06.137553Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [
"/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/xgboost/compat.py:36: FutureWarning: pandas.Int64Index is deprecated and will be removed from pandas in a future version. Use pandas.Index with the appropriate dtype instead.\n", " from pandas import MultiIndex, Int64Index\n" ] } ], "source": [ "import dask\n", "import xgboost\n", "import dask_xgboost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "dask-xgboost is a small wrapper around xgboost. Dask sets up XGBoost, hands it the data, and lets XGBoost train in the background using all the workers Dask has available." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's do some training:" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:06.141271Z", "iopub.status.busy": "2022-06-17T02:51:06.141078Z", "iopub.status.idle": "2022-06-17T02:51:11.966535Z", "shell.execute_reply": "2022-06-17T02:51:11.965906Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Exception in thread Thread-4:\n", "Traceback (most recent call last):\n", " File \"/usr/share/miniconda3/envs/dask-examples/lib/python3.9/threading.py\", line 973, in _bootstrap_inner\n", " self.run()\n", " File \"/usr/share/miniconda3/envs/dask-examples/lib/python3.9/threading.py\", line 910, in run\n", " self._target(*self._args, **self._kwargs)\n", " File \"/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask_xgboost/tracker.py\", line 365, in join\n", " while self.thread.isAlive():\n", "AttributeError: 'Thread' object has no attribute 'isAlive'\n" ] } ], "source": [ "# Booster parameters for binary classification\n", "params = {'objective': 'binary:logistic',\n", "          'max_depth': 4, 'eta': 0.01, 'subsample': 0.5,\n", "          'min_child_weight': 0.5}\n", "\n", "# Train for 10 boosting rounds on the Dask workers\n", "bst = dask_xgboost.train(client, params, X_train, y_train, num_boost_round=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualize results" ] }, { "cell_type":
"markdown", "metadata": {}, "source": [ "The `bst` object is a regular `xgboost.Booster` object." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:11.973860Z", "iopub.status.busy": "2022-06-17T02:51:11.973201Z", "iopub.status.idle": "2022-06-17T02:51:11.981268Z", "shell.execute_reply": "2022-06-17T02:51:11.980541Z" } }, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "bst" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This means all the methods described in the [XGBoost documentation][2] are available. We show two examples to illustrate this; they focus on XGBoost rather than on Dask.\n", "\n", "[2]:https://xgboost.readthedocs.io/en/latest/python/python_intro.html#" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot feature importance" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:11.986292Z", "iopub.status.busy": "2022-06-17T02:51:11.985368Z", "iopub.status.idle": "2022-06-17T02:51:12.823144Z", "shell.execute_reply": "2022-06-17T02:51:12.822411Z" } }, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAciklEQVR4nO3de3xU9Z3/8dcHglcEtGKXEBAxysaARQpVt1ajtcYGtWKVxepWFGV9yBbt6tra7qr01wrbXy3oD2t/WF0vWNhVuS0C9Qa69UZBqWKoxTa0SfBeBblYSfjsH+cLDiEJE8jMGfJ9Px+PeWTOOTPnvGcmM+8558ycMXdHRETi1SntACIiki4VgYhI5FQEIiKRUxGIiERORSAiEjkVgYhI5FQE0iZm9iUzez3tHM0xswozq2tl+hfNbLWZbTCzc/MYbY+Z2ffM7Bdp55COSUUQCTNbY2abw4vgttPULK7nZla6bdjd/8fdB+Qo471m9sNczDv4ATDV3bu6+5w9mVG4P09vn1i75u63uPvl+Vpea8zsZjObnnYOaT9FaQeQvDrb3Z9IO0SKDgdeSzsEgJkVuXtD2jnaysz0mtERubtOEZyANcDpLUwrBZ4G1gHvAf8Zxj8DOLAR2AD8PVAB1DWZ778Ar4TL3Q18FlgIfAQ8ARyccfmHgLfCsp4BysP4scAW4JOwrP8O44uBR4B3gRpgfMa89gfuBT4AqkOOuhZu4x+ArcDmMP99ge4h75tAPfBDoHO4/JHAU8D74T55EOgRpj3QZF7XN71fmt7nwM3Aw8B0YD1weWvLbyb/zcD0cL5feFwuBWrD7b8SGBYehw9J1ny2XXc08Czw/8L9/jvgyxnTi4F5wF+AN4Armiw3M/c/hcdoS7jtvw2XuxRYFR7zPwL/mDGPCqAOuBZ4J9zeS5s8jrcCfwr5fg3sH6adADwXbtNvgYq0n0sd8ZR6AJ3y9EC3XgQzgO+TbCrcDzgpY5oDpRnDO7zghfm+QPLi3zs80V8CjiN5sX0KuCnj8pcBB4VpU4AVGdPuBX6YMdwJWA7cCOwD9A8vMpVh+iTgf4BDgD7ASloogubuA2AO8P+BA4HDgKXbXsBIyvErIWdPktKa0sq8Kpoum52LYAtwbrhd+7e2/Gay38zORfDz8HidAXwc5ndYxuNwSrj8aKAB+DbQhaTQ1wGHhOlPAz8L8xpMUrpfbiX39iwZ+YaTlKcBpwCbgCEZ900Dyaa5LkBVmH5wmH4HsCTk7gz8Xbjfe5MUcVVY9lfCcM+0n08d7ZR6AJ3y9EAnL0obSN5ZbTtdEabdD0wDSpq5XjZFcFHG8CPAnRnD3wLmtJCpR5h/9zB8LzsWwfHAn5tc5wbgP8L5PwJnZkwbS5ZFQFJcfyW88wzjLgQWt3Ddc4GXm5tXc/dLM8u7GXgmY1pbl7/9xZdPi6B3xvT3gb9v8jhcE86PBtYCljF9KfAPJAXaCByUMW0icG9zuZtmaeW+ngNcnXHfbAaKMqa/Q/Juv1OY9rlm5vEd4IEm434FXJLmc6kjnrS9Ly7nevP7CK4H/g+w1Mw+AG5193vaMN+3M85vbma4K4CZdQZ+BFxA8i57a7jMoSTvUJs6HCg2sw8zxnUmWQuAZJNGbca0P7Uh8+Ek707fNLNt4zptm5+ZHQbcDnyJZA2mE8kmmD2RmbXV5Wcpq/s9qPfwShr8ieT+Kwb+4u4fNZk2tIXczTKzrwI3AUeT3I4DgFczLvK+77hPZFPIdyjJmsgfmpnt4cAFZnZ2xrguwOJd5ZG2UREI7v4WcAWAmZ0EPGFmz7j7G+28qG8AXwNOJ3m33J3kxXXbK2HTQ+HWAjXuflQL83uT5B3tth3AfduQpZbkHfmh3vxO24khz7Hu/n74uGnmp6yaZt1I8uIHbC+9nk0uk3mdXS2/vfU2M8sog74k+wXWAoeY2UEZZdCXZJ/FNk1v6w7DZrYvyRrIN4G57r7FzObw6ePamvdINmsdSbIPIFMtyRr
BFVnMR/aAPj4qmNkFZlYSBj8geaI3huG3SbbNt4eDSF783id50bylyfSmy1oKrDez75jZ/mbW2cwGmtmwMP2/gBvM7OCQ/1vZBnH3N4HHgFvNrJuZdTKzI83slIysG4APzaw3yY7o1rL+HtjPzIabWRfgX0m2c+/u8tvbYcB4M+tiZhcAZcACd68l2Rk70cz2M7NjgTEkO8db8jbQz8y2vX7sQ3Jb3wUawtrBGdmEcvetwD3AT82sODzGJ4ZymQ6cbWaVYfx+4bsiJa3PVdpKRRCX/27yPYLZYfww4EUz20DyLvFqd68J024G7jOzD81s5B4u/36SzQ71JJ/yeaHJ9LuBY8Ky5rh7I3A2yQ7MGpJ3j78gWZMAmBDmV0PyovpAG/N8k+RFrJqkAB8GemXMewjJJqtHgVlNrjsR+NeQ9Tp3XwdcFfLVk6whtPjltiyW395eBI4iuQ9/BJzv7u+HaReS7HdYC8wm2bn/eCvzeij8fd/MXgprEuNJivkDkjW/eW3Idh3JZqTfkHxy6d+BTqGkvgZ8j6RkakkKWa9b7cx23GwoIh2NmY0GLnf3k9LOIoVJzSoiEjkVgYhI5LRpSEQkclojEBGJ3F75PYIePXp4aWnpri+YRxs3buTAAw9MO8ZOCjFXIWaCwsxViJlAudqiUDItX778PXdv+t2WRNpfbd6d09FHH+2FZvHixWlHaFYh5irETO6FmasQM7krV1sUSiZgmbfwmqpNQyIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5c/e0M7RZ3/6l3mnkbWnH2MG1gxq49dWitGPspBBzFWImKMxchZgJlKstmmZaM2l4KjnMbLm7D21umtYIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEUnZ5MmTKS8vZ+DAgVx44YV8/PHHPPTQQ5SXl9OpUyeWLVvW4nUXLVrEgAEDKC0tZdKkSbu1/FSKwMzGm9kqM3vQzG43szfM7BUzG5JGHhGRtNTX13P77bezbNkyVq5cSWNjIzNnzmTgwIHMmjWLk08+ucXrNjY2Mm7cOBYuXEh1dTUzZsygurq6zRnSWiO4CqgCHgSOCqexwJ0p5RERSU1DQwObN2+moaGBTZs2UVxcTFlZGQMGDGj1ekuXLqW0tJT+/fuzzz77MGrUKObOndvm5ee9CMzs50B/YB4wG7jfEy8APcysV74ziYikpXfv3lx33XX07duXXr160b17d84444ysrltfX0+fPn22D5eUlFBfX9/mDHkvAne/ElgLnAo8DtRmTK4Dejd3PTMba2bLzGzZhvXrcx9URCQPPvjgA+bOnUtNTQ1r165l48aNTJ8+PavruvtO48yszRnS3lncXOKdbxng7tPcfai7D+3arVuOY4mI5McTTzzBEUccQc+ePenSpQvnnXcezz33XFbXLSkpobb20/fSdXV1FBcXtzlD2kVQB/TJGC4hWVsQEYlC3759eeGFF9i0aRPuzpNPPklZWVlW1x02bBirV6+mpqaGTz75hJkzZ3LOOee0OUPaRTAP+KYlTgDWufubKWcSEcmb448/nvPPP58hQ4YwaNAgtm7dytixY5k9ezYlJSU8//zzDB8+nMrKSgDWrl1LVVUVAEVFRUydOpXKykrKysoYOXIk5eXlbc5Q1K63qO0WkHx66A1gE3BpunFERPJvwoQJTJgwYYdxI0aMYMSIETtdtri4mAULFmwfrqqq2l4MuyuVInD
3fhmD49LIICIiibQ3DYmISMpUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIhETkUgIhI5FYGISORUBCIikVMRiIjEzt33utPRRx/thWbx4sVpR2hWIeYqxEzuhZmrEDO5K1dbFEomYJm38JqqNQIRkcipCEREIqciEBGJnIpARCRyKgIRkcipCEREIqciEBGJnIpARCRyKgIRkcipCEREIleUdoDdsXlLI/2++2jaMXZw7aAGRhdYJijMXIWYCXbOtWbS8BTTiOSP1ghERCKnIhARiVxWRWBmR5rZvuF8hZmNN7MeOU0mIiJ5ke0awSNAo5mVAncDRwC/zFkqERHJm2yLYKu7NwAjgCnu/m2gV+5iiYhIvmRbBFvM7ELgEmB+GNclN5FERCSfsi2CS4ETgR+5e42ZHQFMz10sERHJl6y+R+Du1Wb2HaBvGK4BJuUymIiI5Ee2nxo6G1gBLArDg81sXg5ziYhInmS7aehm4AvAhwDuvoLkk0MiIrKXy7YIGtx9XZNx3t5hREQk/7I91tBKM/sG0NnMjgLGA8/lLpaIiORLtmsE3wLKgb+SfJFsHXBNjjKJiEge7XKNwMw6A/Pc/XTg+7mPJCIi+bTLNQJ3bwQ2mVn3POQREZE8y3YfwcfAq2b2OLBx20h3H5+TVCIikjfZFsGj4SQiIh1MVjuL3f2+5k65DidSiBobGznuuOM466yzAHjooYcoLy+nU6dOLFu2rMXrLVq0iAEDBlBaWsqkSfpivhSObL9ZXGNmf2x62t2Fht8zWGVmD4bhYWbWaGbn7+48RfLltttuo6ysbPvwwIEDmTVrFieffHKL12lsbGTcuHEsXLiQ6upqZsyYQXV1dT7iiuxSth8fHQoMC6cvAbezZweduwqocveLwqeS/h341R7MTyQv6urqePTRR7n88su3jysrK2PAgAGtXm/p0qWUlpbSv39/9tlnH0aNGsXcuXNzHVckK9luGno/41Tv7lOA03ZngWb2c6A/MM/Mvk3yHYVHgHd2Z34i+XTNNdfw4x//mE6d2vYrr/X19fTp02f7cElJCfX19e0dT2S3ZLWz2MyGZAx2IllDOGh3FujuV5rZmcCpwL4kX1A7jWRto7UMY4GxAAd/pifddmfhIntg/vz5HHbYYXz+859nyZIlbbqu+85HZDGzdkomsmey/dTQrRnnG4AaYGQ7LH8K8B13b9zVk8LdpwHTAPr2L9VxjiTvnn32WebNm8eCBQv4+OOPWb9+PRdffDHTp+96K2lJSQm1tbXbh+vq6iguLs5lXJGsZVsEY9x9h53D4cdp9tRQYGYogUOBKjNrcPc57TBvkXY1ceJEJk6cCMCSJUv4yU9+klUJAAwbNozVq1dTU1ND7969mTlzJr/8pX72WwpDths6H85yXJu4+xHu3s/d+4X5XaUSkL3N7NmzKSkp4fnnn2f48OFUVlYCsHbtWqqqqgAoKipi6tSpVFZWUlZWxsiRIykvL08ztsh2ra4RmNnfkhxsrruZnZcxqRuwXy6DiRSyiooKKioqABgxYgQjRozY6TLFxcUsWLBg+3BVVdX2YhApJLvaNDQAOAvoAZydMf4j4IrdXWhYA2g6bvTuzk9ERHZfq0Xg7nOBuWZ2ors/n6dMIiKSR9nuLH7ZzMaRbCbavknI3S/LSSoREcmbbHcWPwD8DVAJPA2UkGweEhGRvVy2RVDq7v8GbAwHmxsODMpdLBERyZdsi2BL+PuhmQ0EugP9cpJIRETyKtt9BNPM7GDg34B5QFfgxpylEhGRvMmqCNz9F+Hs0yQHjBMRkQ4i298j+KyZ3W1mC8PwMWY2JrfRREQkH7LdR3Avye8FbDtK1u+Ba3KQR0RE8izbIjjU3f8L2Arg7g1AY85SiYhI3mRbBBvN7DOAA5jZCcC6nKUSEZG8yfZTQ/9M8mmhI83sWaAnoN8XFhHpAHZ19NG+7v5nd3/JzE4hOQi
dAa+7+5bWrisiInuHXW0ampNx/j/d/TV3X6kSEBHpOHZVBJm/H6nvD4iIdEC7KgJv4byIiHQQu9pZ/DkzW0+yZrB/OE8YdnfvltN0Ldi/S2denzQ8jUW3aMmSJay5qCLtGDspxFyFmAkKN5dIru3qh2k65yuIiIikI9vvEYiISAelIhARiZyKQEQkcioCEZHIqQhERCKnIhARiZyKQEQkcioCEZHIqQhERCKX7e8RFJTNWxrp991H046xg2sHNTC6wDJBernWFNghQESkZVojEBGJnIpARCRyKgIRkcipCEREIqciEBGJnIpARCRyKgIRkcipCEREIqciEBGJnIpARCRyKgIRkcipCEREIqciEBGJnIpARCRyKgIRkcipCEREIqcikLyora3l1FNPpaysjNGjR3PbbbcBsGLFCk444QQGDx7M0KFDWbp0abPXX7RoEQMGDKC0tJRJkyblM7pIh5ezIjCz8Wa2ysweMbPnzeyvZnZdxvQ+ZrY4XOY1M7s6V1kkfUVFRdx6662sWrWKn/3sZ9xxxx1UV1dz/fXXc9NNN7FixQp+8IMfcP311+903cbGRsaNG8fChQuprq5mxowZVFdXp3ArRDqmXP5U5VXAV4GNwOHAuU2mNwDXuvtLZnYQsNzMHnd3PcM7oF69etGrVy8ADjjgAMrKyqivr8fMWL9+PQDr1q2juLh4p+suXbqU0tJS+vfvD8CoUaOYO3cuxxxzTP5ugEgHlpMiMLOfA/2BecA97j7ZzHb4EVt3fxN4M5z/yMxWAb0BFUEH99Zbb/Hyyy9z/PHHM2XKFCorK7nuuuvYunUrzz333E6Xr6+vp0+fPtuHS0pKePHFF/MZWaRDy8mmIXe/ElgLnOruk3d1eTPrBxwHtPjsNrOxZrbMzJZtCO8gZe+zYcMGbrzxRqZMmUK3bt248847mTx5MrW1tUyePJkxY8bsdB1332mcmeUjrkgUUt9ZbGZdgUeAa9y9xVd4d5/m7kPdfWjXbt3yF1DazZYtW/j617/O6aefznnnnQfAfffdt/38BRdc0OzO4pKSEmpra7cP19XVNbsJSUR2T6pFYGZdSErgQXeflWYWyS13Z8yYMZSVlTFy5Mjt44uLi3n66acBeOqppzjqqKN2uu6wYcNYvXo1NTU1fPLJJ8ycOZNzzjknb9lFOrpc7ixulSXr9ncDq9z9p2nlkPx49tlneeCBBxg0aBDz58+na9eu3HLLLdx1111cffXVNDQ0sN9++zFt2jQA1q5dy+WXX86CBQsoKipi6tSpVFZW0tjYyGWXXUZ5eXnKt0ik48h5EZjZ3wDLgG7AVjO7BjgGOBb4B+BVM1sRLv49d1+Q60ySfyeddNL2bf1LliyhoqJi+7Tly5fvdPni4mIWLPj0X6Gqqoqqqqqc5xSJUc6KwN37ZQyWNHORXwPa4ycikrLUdxaLiEi6VAQiIpFTEYiIRE5FICISORWBiEjkVAQiIpFTEYiIRE5FICISORWBiEjkVAQiIpFTEYiIRE5FICISORWBiEjkVAQiIpFTEYiIRE5FICISudR+qnJP7N+lM69PGp52jB0sWbKENRdVpB1jJ4WaS0QKh9YIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyKkIREQipyIQEYmcikBEJHIqAhGRyJm7p52hzczsI+D1tHM0cSjwXtohmlGIuQoxExRmrkLMBMrVFoWS6XB379nchKJ8J2knr7v70LRDZDKzZYWWCQozVyFmgsLMVYiZQLnaohAzNaVNQyIikVMRiIhEbm8tgmlpB2hGIWaCwsxViJmgMHMVYiZQrrYoxEw72Ct3FouISPvZW9c
IRESknagIREQit1cVgZmdaWavm9kbZvbdFHPcY2bvmNnKjHGHmNnjZrY6/D04z5n6mNliM1tlZq+Z2dUFkms/M1tqZr8NuSYUQq6QobOZvWxm8wso0xoze9XMVpjZskLIZWY9zOxhM/td+P86sQAyDQj30bbTejO7pgByfTv8n680sxnh/z/1/6td2WuKwMw6A3cAXwWOAS40s2NSinMvcGaTcd8FnnT3o4Anw3A+NQDXunsZcAIwLtw/aef6K3Cau38OGAycaWYnFEAugKuBVRnDhZAJ4FR3H5zx2fO0c90GLHL3vwU+R3KfpZrJ3V8P99Fg4PPAJmB2mrnMrDcwHhjq7gOBzsCoNDNlzd33ihNwIvCrjOEbgBtSzNMPWJkx/DrQK5zvRfKltzTvr7nAVwopF3AA8BJwfNq5gBKSJ+VpwPxCeQyBNcChTcallgvoBtQQPlhSCJmayXgG8GzauYDeQC1wCMmXdeeHbAVzX7V02mvWCPj0Tt6mLowrFJ919zcBwt/D0gpiZv2A44AXCyFX2ASzAngHeNzdCyHXFOB6YGvGuLQzATjwmJktN7OxBZCrP/Au8B9hM9ovzOzAlDM1NQqYEc6nlsvd64GfAH8G3gTWuftjaWbK1t5UBNbMOH32tQkz6wo8Alzj7uvTzgPg7o2erMKXAF8ws4Fp5jGzs4B33H15mjla8EV3H0KyCXScmZ2ccp4iYAhwp7sfB2ykgDZtmNk+wDnAQwWQ5WDga8ARQDFwoJldnG6q7OxNRVAH9MkYLgHWppSlOW+bWS+A8PedfAcwsy4kJfCgu88qlFzbuPuHwBKS/Stp5voicI6ZrQFmAqeZ2fSUMwHg7mvD33dItnl/IeVcdUBdWIsDeJikGFK/r4KvAi+5+9thOM1cpwM17v6uu28BZgF/l3KmrOxNRfAb4CgzOyK8CxgFzEs5U6Z5wCXh/CUk2+jzxswMuBtY5e4/LaBcPc2sRzi/P8mT5Xdp5nL3G9y9xN37kfwfPeXuF6eZCcDMDjSzg7adJ9m+vDLNXO7+FlBrZgPCqC8D1WlmauJCPt0sBOnm+jNwgpkdEJ6PXybZsV4o91XL0t5J0cadMVXA74E/AN9PMccMkm2AW0jeMY0BPkOy83F1+HtInjOdRLKp7BVgRThVFUCuY4GXQ66VwI1hfKq5MvJV8OnO4rTvq/7Ab8PptW3/4wWQazCwLDyGc4CD084Uch0AvA90zxiX9n01geSNzkrgAWDftDNlc9IhJkREIrc3bRoSEZEcUBGIiERORSAiEjkVgYhI5FQEIiKR21t/vF6k3ZlZI/Bqxqhz3X1NSnFE8kYfHxUJzGyDu3fN4/KK3L0hX8sTaYk2DYlkycx6mdkz4fj3K83sS2H8mWb2UvjNhSfDuEPMbI6ZvWJmL5jZsWH8zWY2zcweA+4P37x+xMx+E05fTPEmSqS0aUjkU/uHo6RCcsyYEU2mf4PkUOg/Cr+PcYCZ9QTuAk529xozOyRcdgLwsrufa2anAfeTfEMXkuPnn+Tum83sl8Bkd/+1mfUFfgWU5ewWijRDRSDyqc2eHCW1Jb8B7gkH95vj7ivMrAJ4xt1rANz9L+GyJwFfD+OeMrPPmFn3MG2eu28O508HjkkOTQNANzM7yN0/aq8bJbIrKgKRLLn7M+Gw0MOBB8zs/wIf0vzh0Fs7bPrGjHGdgBMzikEk77SPQCRLZnY4ye8Y3EVypNchwPPAKWZ2RLjMtk1DzwAXhXEVwHve/O9DPAb8U8YyBucovkiLtEYgkr0K4F/MbAuwAfimu78bfklslpl1IjnW/FeAm0l+1esVkt/TvaT5WTIeuCNcroikQK7M6a0QaUIfHxURiZw2DYmIRE5FICISORWBiEjkVAQiIpFTEYiIRE5FICISORWBiEjk/hf75Lh444KkfgAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "# Show at most 9 features; importance_type defaults to 'weight' (number of splits that use each feature)\n", "ax = xgboost.plot_importance(bst, height=0.8, max_num_features=9)\n", "ax.grid(False, axis=\"y\")\n", "ax.set_title('Estimated feature importance')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When creating the data, we specified that only 4 features are informative, and only 3 features show up as important here." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot the Receiver Operating Characteristic curve\n", "We can evaluate our classifier more thoroughly by plotting the [Receiver Operating Characteristic (ROC) curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic):" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:12.826078Z", "iopub.status.busy": "2022-06-17T02:51:12.825696Z", "iopub.status.idle": "2022-06-17T02:51:12.912622Z", "shell.execute_reply": "2022-06-17T02:51:12.905004Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[02:51:12] WARNING: /home/conda/feedstock_root/build_artifacts/xgboost-split_1645117766796/work/src/learner.cc:1264: Empty dataset at worker: 0\n" ] }, { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", "
| | Array | Chunk |
|---|---|---|
| Bytes | 58.59 kiB | 600 B |
| Shape | (15000,) | (150,) |
| Count | 100 Tasks | 100 Chunks |
| Type | float32 | numpy.ndarray |
" ], "text/plain": [ "dask.array<_predict_part, shape=(15000,), dtype=float32, chunksize=(150,), chunktype=numpy.ndarray>" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y_hat = dask_xgboost.predict(client, bst, X_test).persist()\n", "y_hat" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:12.918906Z", "iopub.status.busy": "2022-06-17T02:51:12.916877Z", "iopub.status.idle": "2022-06-17T02:51:14.911604Z", "shell.execute_reply": "2022-06-17T02:51:14.895763Z" } }, "outputs": [], "source": [ "from sklearn.metrics import roc_curve\n", "\n", "y_test, y_hat = dask.compute(y_test, y_hat)\n", "fpr, tpr, _ = roc_curve(y_test, y_hat)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "execution": { "iopub.execute_input": "2022-06-17T02:51:14.921663Z", "iopub.status.busy": "2022-06-17T02:51:14.920578Z", "iopub.status.idle": "2022-06-17T02:51:15.084702Z", "shell.execute_reply": "2022-06-17T02:51:15.083782Z" } }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import auc\n", "\n", "fig, ax = plt.subplots(figsize=(5, 5))\n", "ax.plot(fpr, tpr, lw=3,\n", "        label='ROC Curve (area = {:.2f})'.format(auc(fpr, tpr)))\n", "ax.plot([0, 1], [0, 1], 'k--', lw=2)\n", "ax.set(\n", "    xlim=(0, 1),\n", "    ylim=(0, 1),\n", "    title=\"ROC Curve\",\n", "    xlabel=\"False Positive Rate\",\n", "    ylabel=\"True Positive Rate\",\n", ")\n", "ax.legend();\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This Receiver Operating Characteristic (ROC) curve shows how well our classifier is doing. The further the curve bends toward the upper-left corner, the better the classifier: a perfect classifier would pass through the upper-left corner, and a random classifier would follow the diagonal line.\n", "\n", "The area under this curve is `area = 0.76`. This is the probability that our classifier assigns a higher score to a randomly chosen positive instance than to a randomly chosen negative instance." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learn more\n", "* Recorded screencast stepping through the real world example above:\n", "* A blog post on dask-xgboost: http://matthewrocklin.com/blog/work/2017/03/28/dask-xgboost\n", "* XGBoost documentation: https://xgboost.readthedocs.io/en/latest/python/python_intro.html#\n", "* Dask-XGBoost documentation: http://ml.dask.org/xgboost.html" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 4 }