Train Models on Large Datasets



Most estimators in scikit-learn are designed to work with NumPy arrays or scipy sparse matrices. These data structures must fit in the RAM of a single machine.

Estimators implemented in Dask-ML work well with Dask Arrays and DataFrames, which can be much larger than a single machine’s RAM because they are distributed across the memory of a cluster of machines.
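As a rough illustration of the idea in plain NumPy (not Dask’s API): a reduction over a large array can be computed one chunk at a time and the partial results combined. This map-and-aggregate pattern is what Dask automates and runs in parallel across threads or machines.

```python
import numpy as np

# Stand-in for a dataset that, at real scale, would not fit in RAM.
X = np.random.default_rng(0).normal(size=(1_000_000, 2))

# Compute the column means chunk by chunk, then combine the partial
# results -- the same map/aggregate pattern Dask applies automatically
# across threads or a cluster of machines.
chunks = np.array_split(X, 10)
partial_sums = [chunk.sum(axis=0) for chunk in chunks]
n_rows = sum(len(chunk) for chunk in chunks)
mean = np.sum(partial_sums, axis=0) / n_rows
```

With Dask, `X.mean(axis=0)` on a chunked dask array builds the same chunk-wise computation graph for you and executes it in parallel.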

%matplotlib inline
from dask.distributed import Client

# Scale up: connect to your own cluster with more resources
client = Client(processes=False, threads_per_worker=4,
                n_workers=1, memory_limit='2GB')




import dask_ml.datasets
import dask_ml.cluster
import matplotlib.pyplot as plt

In this example, we’ll use dask_ml.datasets.make_blobs to generate some random dask arrays.

# Scale up: increase n_samples or n_features
X, y = dask_ml.datasets.make_blobs(n_samples=1000000,
                                   chunks=100000,
                                   centers=3)
X = X.persist()
X
            Array           Chunk
Bytes       15.26 MiB       1.53 MiB
Shape       (1000000, 2)    (100000, 2)
Count       10 Tasks        10 Chunks
Type        float64         numpy.ndarray

We’ll use the k-means implemented in Dask-ML to cluster the points. It uses the k-means|| (read: “k-means parallel”) initialization algorithm, which scales better than k-means++. All of the computation, both during and after initialization, can be done in parallel.
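To make the initialization step concrete, here is a small NumPy sketch of the k-means|| idea (Bahmani et al.): each round samples roughly `oversampling_factor * k` candidate centers in parallel, rather than one center per pass as in k-means++. This is an illustrative sketch only, not Dask-ML’s implementation; the function name and the final reduction step are simplified assumptions.

```python
import numpy as np

def kmeans_parallel_init(X, k, oversampling_factor=2, n_rounds=2, rng=None):
    """Illustrative sketch of k-means|| initialization.

    Each round samples many candidate centers at once (each point
    independently, with probability proportional to its distance from
    the current centers), so a round is embarrassingly parallel across
    chunks of X.  NOT Dask-ML's implementation -- just the idea.
    """
    rng = np.random.default_rng(rng)
    # Start from one uniformly random point.
    centers = X[rng.integers(len(X))][None, :]
    for _ in range(n_rounds):
        # Squared distance of every point to its nearest current center.
        d2 = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(-1).min(1)
        # Sample ~oversampling_factor * k new candidates per round.
        probs = np.minimum(1.0, oversampling_factor * k * d2 / d2.sum())
        new = X[rng.random(len(X)) < probs]
        if len(new):
            centers = np.vstack([centers, new])
    # Reduce the oversampled candidate set down to k centers (here a
    # simple random choice; real implementations recluster candidates).
    idx = rng.choice(len(centers), size=k, replace=False)
    return centers[idx]
```

Because each sampling round touches every point independently, the work maps cleanly onto the chunks of a dask array, which is why this initialization scales better than the inherently sequential k-means++.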

km = dask_ml.cluster.KMeans(n_clusters=3, init_max_iter=2, oversampling_factor=10)
km.fit(X)
/usr/share/miniconda3/envs/dask-examples/lib/python3.9/site-packages/dask/ UserWarning: Running on a single-machine scheduler when a distributed client is active might lead to unexpected results.
KMeans(init_max_iter=2, n_clusters=3, oversampling_factor=10)

We’ll plot a sample of points, colored by the cluster each falls into.

fig, ax = plt.subplots()
ax.scatter(X[::1000, 0], X[::1000, 1], marker='.', c=km.labels_[::1000],
           cmap='viridis', alpha=0.25);

For all the estimators implemented in Dask-ML, see the API documentation.