{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Generalized Linear Models\n", "\n", "This notebook introduces the algorithms within [Dask-GLM](https://github.com/dask/dask-glm) for [Generalized Linear Models](https://en.wikipedia.org/wiki/Generalized_linear_model)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Start Dask Client for Dashboard\n", "\n", "Starting the Dask Client is optional. It provides a dashboard that \n", "is useful for gaining insight into the computation. \n", "\n", "The link to the dashboard will become visible when you create the client below. We recommend having it open on one side of your screen while using your notebook on the other side. Arranging your windows can take some effort, but seeing both at the same time is very useful when learning." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:21:46.509501Z", "iopub.status.busy": "2022-07-27T19:21:46.509058Z", "iopub.status.idle": "2022-07-27T19:21:47.686732Z", "shell.execute_reply": "2022-07-27T19:21:47.685976Z" } }, "outputs": [ { "data": { "text/html": [ "
Client: Client-5b571f45-0de1-11ed-a361-000d3a8f7959 | Connection method: Cluster object | Cluster type: distributed.LocalCluster (94f79b63) | Dashboard: http://10.1.1.64:8787/status | Workers: 1 | Total threads: 4 | Total memory: 1.86 GiB | Status: running | Using processes: False
" ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask.distributed import Client, progress\n", "client = Client(processes=False, threads_per_worker=4,\n", " n_workers=1, memory_limit='2GB')\n", "client" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make a random dataset" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:21:47.690122Z", "iopub.status.busy": "2022-07-27T19:21:47.689737Z", "iopub.status.idle": "2022-07-27T19:21:47.926753Z", "shell.execute_reply": "2022-07-27T19:21:47.925891Z" } }, "outputs": [ { "data": { "text/html": [ "\n", " \n", " \n", " \n", " \n", "
Array: 152.59 MiB, shape (200000, 100), 20 tasks, float64. Chunk: 7.63 MiB, shape (10000, 100), 20 chunks, numpy.ndarray.
" ], "text/plain": [ "dask.array" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from dask_glm.datasets import make_regression\n", "X, y = make_regression(n_samples=200000, n_features=100, n_informative=5, chunksize=10000)\n", "X" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:21:47.929938Z", "iopub.status.busy": "2022-07-27T19:21:47.929646Z", "iopub.status.idle": "2022-07-27T19:21:48.100130Z", "shell.execute_reply": "2022-07-27T19:21:48.096704Z" } }, "outputs": [], "source": [ "import dask\n", "X, y = dask.persist(X, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solve with a GLM algorithm\n", "\n", "*We also recommend looking at the \"Graph\" dashboard during execution if available*" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:21:48.103783Z", "iopub.status.busy": "2022-07-27T19:21:48.103577Z", "iopub.status.idle": "2022-07-27T19:21:59.768184Z", "shell.execute_reply": "2022-07-27T19:21:59.767624Z" } }, "outputs": [], "source": [ "import dask_glm.algorithms\n", "\n", "b = dask_glm.algorithms.admm(X, y, max_iter=5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solve with a difference GLM algorithm" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:21:59.772140Z", "iopub.status.busy": "2022-07-27T19:21:59.771482Z", "iopub.status.idle": "2022-07-27T19:22:04.451244Z", "shell.execute_reply": "2022-07-27T19:22:04.450652Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask/core.py:119: RuntimeWarning: overflow encountered in exp\n", " return func(*(_execute_task(a, cache) for a in args))\n" ] } ], "source": [ "b = dask_glm.algorithms.proximal_grad(X, y, max_iter=5)" ] }, { "cell_type": 
"markdown", "metadata": {}, "source": [ "## Customizable with different families and regularizers\n", "\n", "The Dask-GLM project is nicely modular, allowing for different GLM families and regularizers, including a relatively straightforward interface for implementing custom ones." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:04.457858Z", "iopub.status.busy": "2022-07-27T19:22:04.455705Z", "iopub.status.idle": "2022-07-27T19:22:08.991631Z", "shell.execute_reply": "2022-07-27T19:22:08.991093Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask/core.py:119: RuntimeWarning: overflow encountered in exp\n", " return func(*(_execute_task(a, cache) for a in args))\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask/core.py:119: RuntimeWarning: overflow encountered in exp\n", " return func(*(_execute_task(a, cache) for a in args))\n" ] } ], "source": [ "import dask_glm.families\n", "import dask_glm.regularizers\n", "\n", "family = dask_glm.families.Poisson()\n", "regularizer = dask_glm.regularizers.ElasticNet()\n", "\n", "b = dask_glm.algorithms.proximal_grad(\n", " X, y, \n", " max_iter=5, \n", " family=family,\n", " regularizer=regularizer,\n", ")" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:08.995266Z", "iopub.status.busy": "2022-07-27T19:22:08.994623Z", "iopub.status.idle": "2022-07-27T19:22:09.066066Z", "shell.execute_reply": "2022-07-27T19:22:09.065491Z" } }, "outputs": [], "source": [ "dask_glm.families.Poisson??" 
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "execution": { "iopub.execute_input": "2022-07-27T19:22:09.069331Z", "iopub.status.busy": "2022-07-27T19:22:09.068870Z", "iopub.status.idle": "2022-07-27T19:22:09.084687Z", "shell.execute_reply": "2022-07-27T19:22:09.084032Z" } }, "outputs": [], "source": [ "dask_glm.regularizers.ElasticNet??" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 4 }