Hyperparameter optimization is the process of deducing model parameters that can’t be learned from data. This process is often time- and resource-consuming, especially in the context of deep learning. A good description of this process can be found at “Tuning the hyper-parameters of an estimator,” and the issues that arise are concisely summarized in Dask-ML’s documentation of “Hyper Parameter Searches.”

There’s a host of libraries and frameworks out there to address this problem. Scikit-Learn’s model selection module has been mirrored in Dask-ML and auto-sklearn, both of which offer advanced hyperparameter optimization techniques. Other implementations that don’t follow the Scikit-Learn interface include Ray Tune, AutoML and Optuna.

Ray recently released tune-sklearn, a wrapper around Ray Tune that mirrors the Scikit-Learn API. The introduction of this library states the following:

Cutting edge hyperparameter tuning techniques (Bayesian optimization, early stopping, distributed execution) can provide significant speedups over grid search and random search.

However, the machine learning ecosystem is missing a solution that provides users with the ability to leverage these new algorithms while allowing users to stay within the Scikit-Learn API. In this blog post, we introduce tune-sklearn [Ray’s tuning library] to bridge this gap. Tune-sklearn is a drop-in replacement for Scikit-Learn’s model selection module with state-of-the-art optimization features.

(From Ray’s blog post “GridSearchCV 2.0 — New and Improved.”)

This claim is inaccurate: for over a year Dask-ML has provided access to “cutting edge hyperparameter tuning techniques” with a Scikit-Learn compatible API. To correct their statement, let’s look at each of the features that Ray’s tune-sklearn provides, and compare them to Dask-ML:

Here’s what [Ray’s] tune-sklearn has to offer:

  1. Consistency with Scikit-Learn API
  2. Modern hyperparameter tuning techniques
  3. Framework support
  4. Scale up … [to] multiple cores and even multiple machines.

[Ray’s] Tune-sklearn is also fast.

Dask-ML’s model selection module has every one of these features:

  • Consistency with Scikit-Learn API: Dask-ML’s model selection API mirrors the Scikit-Learn model selection API.
  • Modern hyperparameter tuning techniques: Dask-ML offers state-of-the-art hyperparameter tuning techniques.
  • Framework support: Dask-ML model selection supports many libraries including Scikit-Learn, PyTorch, Keras, LightGBM and XGBoost.
  • Scale up: Dask-ML supports distributed tuning (how could it not?) and larger-than-memory datasets.

Dask-ML is also fast. In “Speed” we show a benchmark between Dask-ML, Ray and Scikit-Learn:

Only time-to-solution is relevant; all of these methods produce similar model scores. See “Speed” for details.

Now, let’s walk through how Dask-ML provides each of the five features above.

Consistency with the Scikit-Learn API

Dask-ML is consistent with the Scikit-Learn API.

Here’s how to use Scikit-Learn’s, Dask-ML’s and Ray’s tune-sklearn hyperparameter optimization:

## Trimmed example; see appendix for more detail
from sklearn.model_selection import RandomizedSearchCV
search = RandomizedSearchCV(model, params, ...)
search.fit(X, y)

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

from tune_sklearn import TuneSearchCV
search = TuneSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

The definitions of model and params follow the normal Scikit-Learn definitions as detailed in the appendix.

Clearly, both Dask-ML and Ray’s tune-sklearn are Scikit-Learn compatible. Now let’s focus on how each search performs and how it’s configured.

Modern hyperparameter tuning techniques

Dask-ML offers state-of-the-art hyperparameter tuning techniques in a Scikit-Learn interface.

The introduction of Ray’s tune-sklearn made this claim:

tune-sklearn is the only Scikit-Learn interface that allows you to easily leverage Bayesian Optimization, HyperBand and other optimization techniques by simply toggling a few parameters.

The state-of-the-art in hyperparameter optimization is currently “Hyperband.” Hyperband reduces the amount of computation required with a principled early stopping scheme; past that, it’s the same as Scikit-Learn’s popular RandomizedSearchCV.
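To make “a principled early stopping scheme” concrete, here is a minimal sketch of how the Hyperband paper splits a training budget into brackets of successive halving. It assumes the paper’s formulation, with max_iter (the most partial_fit calls any single model receives) and a reduction factor eta; the function names are hypothetical, and Dask-ML’s and Ray’s implementations differ in their details. Each bracket starts some number of models on a small budget and keeps roughly the best 1/eta of them at every rung.

```python
def floor_log(n: int, base: int) -> int:
    """Largest integer k with base**k <= n (exact, no floating point)."""
    k = 0
    while base ** (k + 1) <= n:
        k += 1
    return k


def hyperband_brackets(max_iter: int, eta: int = 3) -> list:
    """Successive-halving schedule for every Hyperband bracket.

    Returns one list per bracket; each entry is (n_models, iters_per_model)
    for one rung. Follows the bracket formulas of Li et al. (2016).
    """
    s_max = floor_log(max_iter, eta)
    brackets = []
    for s in range(s_max, -1, -1):
        # Models sampled at the start of bracket s: ceil((s_max + 1) * eta**s / (s + 1))
        n = -(-((s_max + 1) * eta**s) // (s + 1))
        rungs = []
        for i in range(s + 1):
            n_i = n // eta**i                  # models still training at rung i
            r_i = max_iter * eta**i // eta**s  # partial_fit calls per model so far
            rungs.append((n_i, r_i))
        brackets.append(rungs)
    return brackets


for bracket in hyperband_brackets(81, eta=3):
    print(bracket)
```

With max_iter=81 and eta=3, the most aggressive bracket starts 81 models with 1 iteration each and keeps one for the full 81, while the most conservative bracket trains 5 models for all 81 iterations, which is plain random search. Dask-ML exposes a comparable reduction factor through its aggressiveness parameter; treat that correspondence as an assumption to check against its documentation.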

Hyperband works, and as a result it’s very popular. Since Li et al. introduced Hyperband in 2016, the paper has been cited over 470 times, and the algorithm has been implemented in many libraries including Dask-ML, Ray Tune, keras-tune, Optuna, AutoML[1] and Microsoft’s NNI. The original paper shows a rather drastic improvement over all the relevant implementations,[2] and this improvement persists in follow-up work.[3] Some illustrative results from Hyperband are below:

All algorithms are configured to do the same amount of work except “random 2x,” which does twice as much work. “hyperband (finite)” is similar to Dask-ML’s default implementation, and “bracket s=4” is similar to Ray’s default implementation. “random” is a random search. SMAC,[4] spearmint[5] and TPE[6] are popular Bayesian algorithms.

Hyperband is undoubtedly a “cutting edge” hyperparameter optimization technique. Dask-ML and Ray both offer Scikit-Learn compatible implementations of this algorithm that work in similar ways, and Dask-ML’s implementation also comes with a rule of thumb for configuration. The documentation of both Dask-ML and Ray encourages the use of Hyperband.
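The rule of thumb mentioned above, as I understand it from Dask-ML’s HyperbandSearchCV documentation, starts from two quantities: n_params, how many hyperparameter combinations to sample, and n_examples, how many examples the longest-trained model should see. Treat the formulas below as a paraphrase to verify against the current docs; the helper name is hypothetical.

```python
def hyperband_rule_of_thumb(n_params: int, n_examples: int) -> dict:
    """Hypothetical helper paraphrasing Dask-ML's HyperbandSearchCV rule of thumb.

    n_params: number of hyperparameter combinations to sample.
    n_examples: examples the longest-trained model should see
        (for example, 50 passes through a 60,000-row dataset).
    """
    max_iter = n_params                  # partial_fit calls for the best model
    chunk_size = n_examples // n_params  # rows per chunk of X and y
    return {"max_iter": max_iter, "chunk_size": chunk_size}


config = hyperband_rule_of_thumb(n_params=100, n_examples=50 * 60_000)
print(config)
```

Because max_iter * chunk_size is approximately n_examples, the best model ends up seeing about as many examples as requested, while the search samples about n_params candidates.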

Ray does support using their Hyperband implementation on top of a technique called Bayesian sampling. This changes the hyperparameter sampling scheme for model initialization. This can be used in conjunction with Hyperband’s early stopping scheme. Adding this option to Dask-ML’s Hyperband implementation is future work for Dask-ML.

Framework support

Dask-ML model selection supports many libraries including Scikit-Learn, PyTorch, Keras, LightGBM and XGBoost.

Ray’s tune-sklearn supports these frameworks:

tune-sklearn is used primarily for tuning Scikit-Learn models, but it also supports and provides examples for many other frameworks with Scikit-Learn wrappers such as Skorch (Pytorch), KerasClassifiers (Keras), and XGBoostClassifiers (XGBoost).

Clearly, Dask-ML and Ray support many of the same libraries.

However, both Dask-ML and Ray have some qualifications. Certain libraries don’t implement partial_fit,[7] so not all of the modern hyperparameter optimization techniques can be used with them. Here’s a table comparing different libraries and their support in Dask-ML’s model selection and Ray’s tune-sklearn:

Model Library Dask-ML support Ray support Dask-ML: early stopping? Ray: early stopping?
Scikit-Learn ✔* ✔*
PyTorch (via Skorch)
Keras (via SciKeras) ✔** ✔**
LightGBM
XGBoost

* Only for the models that implement partial_fit.
** Thanks to work by the Dask developers around scikeras#24.
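The partial_fit qualification comes down to whether an estimator can be trained one chunk at a time, so that a search can pause a model, score it, and then either resume or stop it. Here is a minimal duck-typing check, using two hypothetical toy estimators rather than real library classes:

```python
class BatchOnlyModel:
    """Hypothetical estimator that can only train on all data at once."""

    def fit(self, X, y):
        self.n_seen_ = len(X)
        return self


class IncrementalModel(BatchOnlyModel):
    """Hypothetical estimator that can also train on one chunk at a time."""

    def partial_fit(self, X, y, classes=None):
        self.n_seen_ = getattr(self, "n_seen_", 0) + len(X)
        return self


def supports_early_stopping(estimator) -> bool:
    """True if a search could pause and resume training via partial_fit."""
    return callable(getattr(estimator, "partial_fit", None))


print(supports_early_stopping(BatchOnlyModel()))    # False
print(supports_early_stopping(IncrementalModel()))  # True
```

Estimators that fail this check can still be searched, but each candidate has to be trained to completion, which matches the fallback behavior described in Ray’s README.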

By this measure, Dask-ML and Ray model selection have the same level of framework support. Of course, Dask has tangential integration with LightGBM and XGBoost through Dask-ML’s xgboost module and dask-lightgbm.

Scale up

Dask-ML supports distributed tuning (how could it not?), i.e., parallelization across multiple cores and machines. It also supports larger-than-memory data.

[Ray’s] Tune-sklearn leverages Ray Tune, a library for distributed hyperparameter tuning, to efficiently and transparently parallelize cross validation on multiple cores and even multiple machines.

Naturally, Dask-ML also scales to multiple cores/machines because it relies on Dask. Dask has wide support for different deployment options that span from your personal machine to supercomputers. Dask will very likely work on top of any computing system you have available, including Kubernetes, SLURM, YARN and Hadoop clusters as well as your personal machine.

Dask-ML’s model selection also scales to larger-than-memory datasets, and this support is thoroughly tested. Support for larger-than-memory data is untested in Ray, and there are no examples detailing how to use Ray Tune with the distributed dataset implementations in PyTorch/Keras.
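Chunk-at-a-time training is also what makes larger-than-memory data workable: a model that implements partial_fit only needs one chunk in memory at a time. Below is a sketch of that idea in plain Python, with a hypothetical RunningMeanModel and a generator standing in for blocks read from disk; Dask-ML works with chunked Dask arrays instead.

```python
from typing import Iterator, List


def load_chunks(n_rows: int, chunk_size: int) -> Iterator[List[float]]:
    """Yield rows in chunks; stands in for reading blocks from disk."""
    for start in range(0, n_rows, chunk_size):
        stop = min(start + chunk_size, n_rows)
        yield [float(i) for i in range(start, stop)]


class RunningMeanModel:
    """Hypothetical model whose state is updated chunk by chunk."""

    def __init__(self):
        self.count_ = 0
        self.total_ = 0.0

    def partial_fit(self, X, y=None):
        self.count_ += len(X)
        self.total_ += sum(X)
        return self

    @property
    def mean_(self) -> float:
        return self.total_ / self.count_


model = RunningMeanModel()
for chunk in load_chunks(n_rows=1_000_000, chunk_size=50_000):
    model.partial_fit(chunk)  # only the current chunk needs to be in memory
print(model.count_, model.mean_)
```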

In addition, in “Better and faster hyperparameter optimization with Dask,” I have benchmarked how the time-to-solution of Dask-ML’s model selection module is affected by the number of Dask workers. That is, how does the time to reach a particular accuracy scale with the number of workers $P$? At first it scales like $1/P$, but with a large number of workers the serial portion dictates time to solution according to Amdahl’s Law. Briefly, I found that the speedup of Dask-ML’s HyperbandSearchCV started to saturate around 24 workers for a particular search.
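Amdahl’s Law makes this saturation precise: if a fraction $p$ of the work parallelizes perfectly, the speedup with $P$ workers is $1/((1-p) + p/P)$, which approaches $1/(1-p)$ as $P$ grows. A small Python sketch follows; the 95% parallel fraction is an illustrative assumption, not a measured property of Dask-ML.

```python
def amdahl_speedup(parallel_fraction: float, n_workers: int) -> float:
    """Speedup over one worker when only `parallel_fraction` of the work parallelizes."""
    if not 0.0 <= parallel_fraction <= 1.0:
        raise ValueError("parallel_fraction must be in [0, 1]")
    serial_fraction = 1.0 - parallel_fraction
    return 1.0 / (serial_fraction + parallel_fraction / n_workers)


# Illustrative only: assume 95% of the search parallelizes.
for workers in [1, 8, 24, 64, 256]:
    print(workers, round(amdahl_speedup(0.95, workers), 1))
```

Each additional worker buys less than the one before, and no number of workers pushes the speedup past 1/(1 - 0.95) = 20 under this assumption.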

Speed

Both Dask-ML and Ray are much faster than Scikit-Learn.

Ray’s tune-sklearn introduction runs some benchmarks with the GridSearchCV class found in Scikit-Learn and Dask-ML. A fairer benchmark would be to use Dask-ML’s HyperbandSearchCV, because it is almost the same as the algorithm in Ray’s tune-sklearn. Specifically, I’m interested in comparing these methods:

  • Scikit-Learn’s RandomizedSearchCV. This is a popular implementation, one that I’ve bootstrapped myself with a custom model.
  • Dask-ML’s HyperbandSearchCV. This is an early stopping technique for RandomizedSearchCV.
  • Ray tune-sklearn’s TuneSearchCV. This is a slightly different early stopping technique than HyperbandSearchCV’s.

Each search is configured to perform the same task: sample 100 parameters and train each model for no longer than 100 “epochs,” or passes through the data.[8] Each search is configured as its library’s documentation suggests. Each search uses 8 workers with a single cross-validation split, and a partial_fit call takes one second with 50,000 examples. The complete setup can be found in the appendix.
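A back-of-the-envelope estimate shows why early stopping matters here. Without it, every one of the 100 sampled models trains for all 100 epochs, and each epoch costs roughly one second (one partial_fit call on about 50,000 examples). Assuming perfect load balancing across the 8 workers and ignoring scoring overhead:

```python
n_params = 100           # hyperparameter combinations sampled
max_epochs = 100         # passes through the data per model, at most
seconds_per_epoch = 1.0  # approximate cost of one partial_fit call
n_workers = 8

# Without early stopping, every model trains for the full budget.
total_work_seconds = n_params * max_epochs * seconds_per_epoch
wall_clock_minutes = total_work_seconds / n_workers / 60

print(f"{total_work_seconds:.0f} s of work, ~{wall_clock_minutes:.1f} min on {n_workers} workers")
```

That estimate is in line with the roughly 21 minutes RandomizedSearchCV takes in the appendix; the early stopping searches finish in about 3 minutes because most models stop long before 100 epochs.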

Here’s how long each library takes to complete the same search:

Notably, we didn’t improve the Dask-ML codebase for this benchmark, and ran the code as it’s been for the last year.[9] Still, it’s possible that other artifacts from biased benchmarks crept into this benchmark.

Clearly, Ray and Dask-ML offer similar performance with 8 workers when compared with Scikit-Learn. To Ray’s credit, their implementation is ~15% faster than Dask-ML’s with 8 workers (2.49 vs. 2.93 minutes). We suspect that this performance boost comes from the fact that Ray implements an asynchronous variant of Hyperband. We should investigate this difference between Dask and Ray, and how each balances the tradeoff between the number of FLOPs and time-to-solution. This will vary with the number of workers: the asynchronous variant of Hyperband provides no benefit if used with a single worker.
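One way to see why asynchrony helps with several workers: synchronous successive halving picks a rung’s survivors only after every model in that rung finishes, while an asynchronous rule (such as the one in ASHA, by Li et al.) can promote a model as soon as it ranks in the top 1/eta of the models that have finished so far. This is a simplified sketch of that promotion rule with a hypothetical function name, not Ray’s implementation.

```python
def promotable(rung_scores: dict, already_promoted: set, eta: int = 3) -> list:
    """Models that may move to the next rung right now (asynchronous rule).

    rung_scores: model id -> validation score for models that have
        finished this rung so far; other models may still be running.
    A model is promotable if it ranks in the top len(rung_scores) // eta
    and has not been promoted yet.
    """
    k = len(rung_scores) // eta
    ranked = sorted(rung_scores, key=rung_scores.get, reverse=True)
    return [m for m in ranked[:k] if m not in already_promoted]


# Three of nine models have finished rung 0: one can be promoted already,
# instead of waiting for the other six as synchronous halving would.
scores = {"a": 0.61, "b": 0.74, "c": 0.55}
print(promotable(scores, already_promoted=set()))  # ['b']
```

Promoting early keeps workers busy, at the cost of occasionally promoting a model that a complete rung would have culled; that is one form of the FLOPs vs. time-to-solution tradeoff.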

Dask-ML reaches high scores quickly in serial environments, or when the number of workers is small. Dask-ML prioritizes fitting high-scoring models: if there are 100 models to fit and only 4 workers available, Dask-ML selects the models that have the highest score. This is most relevant in serial environments;[10] see “Better and faster hyperparameter optimization with Dask” for benchmarks. This benchmark doesn’t measure this feature, because it only focuses on time to solution.
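The prioritization idea can be sketched with Python’s heapq: whenever workers free up, hand them to the models with the highest current scores. This illustrates the idea only; the function name is hypothetical and this is not Dask-ML’s implementation.

```python
import heapq


def next_models_to_train(scores: dict, n_free_workers: int) -> list:
    """Return up to `n_free_workers` model ids, highest score first."""
    return heapq.nlargest(n_free_workers, scores, key=scores.get)


# 8 candidate models but only 4 workers: the best-scoring 4 train next.
scores = {f"model-{i}": s for i, s in enumerate([0.2, 0.9, 0.5, 0.7, 0.1, 0.8, 0.3, 0.6])}
print(next_models_to_train(scores, n_free_workers=4))
```

With one worker, this ordering decides which model improves first, which is why the effect is largest in serial environments; with as many workers as models, every model trains at once and the ordering stops mattering.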

Conclusion

Dask-ML and Ray offer the same features for model selection: state-of-the-art features with a Scikit-Learn compatible API, and both implementations have fairly wide support for different frameworks and rely on backends that can scale to many machines.

In addition, the Ray implementation has provided motivation for further development, specifically on the following items:

  1. Adding support for more libraries, including Keras (dask-ml#696, dask-ml#713, scikeras#24). SciKeras is a Scikit-Learn wrapper for Keras that (now) works with Dask-ML model selection because SciKeras models implement the Scikit-Learn model API.
  2. Better documenting the models that Dask-ML supports (dask-ml#699). Dask-ML supports any model that implements the Scikit-Learn interface, and there are wrappers for Keras, PyTorch, LightGBM and XGBoost. Now, Dask-ML’s documentation prominently highlights this fact.

The Ray implementation has also helped motivate and clarify future work. Dask-ML should include the following implementations:

  1. A Bayesian sampling scheme for the Hyperband implementation that’s similar to Ray’s and BOHB’s (dask-ml#697).
  2. A configuration of HyperbandSearchCV that’s well-suited for exploratory hyperparameter searches. An initial implementation is in dask-ml#532, which should be benchmarked against Ray.

Luckily, all of these pieces of development are straightforward modifications because the Dask-ML model selection framework is pretty flexible.

Thank you Tom Augspurger, Matthew Rocklin, Julia Signell, and Benjamin Zaitlen for your feedback, suggestions and edits.

Appendix

Benchmark setup

This is the complete setup for the benchmark between Dask-ML, Scikit-Learn and Ray. Complete details can be found at stsievert/dask-hyperband-comparison.

Let’s create a dummy model that takes 1 second for a partial_fit call with 50,000 examples. This is appropriate for this benchmark; we’re only interested in the time required to finish the search, not how well the models do. Scikit-Learn, Ray and Dask-ML have very similar methods of choosing hyperparameters to evaluate; they differ in their early stopping techniques.

from scipy.stats import uniform
from sklearn.datasets import make_classification
from benchmark import ConstantFunction  # custom module from stsievert/dask-hyperband-comparison

# Search budget used throughout the benchmark
n_params = 100  # hyperparameter combinations to sample
max_iter = 100  # maximum number of epochs (passes through the data)

# This model sleeps for `latency * len(X)` seconds before
# reporting a score of `value`.
model = ConstantFunction(latency=1 / 50e3, max_iter=max_iter)

params = {"value": uniform(0, 1)}
# This dummy dataset mirrors the MNIST dataset
X_train, y_train = make_classification(n_samples=int(60e3), n_features=784)

This model will take 2 minutes to train for 100 epochs (aka passes through the data). Details can be found at stsievert/dask-hyperband-comparison.

Let’s configure our searches to use 8 workers with a single cross-validation split:

from sklearn.model_selection import RandomizedSearchCV, ShuffleSplit
split = ShuffleSplit(test_size=0.2, n_splits=1)
kwargs = dict(cv=split, refit=False)

search = RandomizedSearchCV(model, params, n_jobs=8, n_iter=n_params, **kwargs)
search.fit(X_train, y_train)  # 20.88 minutes

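from dask.distributed import Client
from dask_ml.model_selection import HyperbandSearchCV

# HyperbandSearchCV runs on a Dask cluster; start 8 local workers
client = Client(n_workers=8)
dask_search = HyperbandSearchCV(
    model, params, test_size=0.2, max_iter=max_iter, aggressiveness=4
)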

from tune_sklearn import TuneSearchCV
ray_search = TuneSearchCV(
    model, params, n_iter=n_params, max_iters=max_iter, early_stopping=True, **kwargs
)

dask_search.fit(X_train, y_train)  # 2.93 minutes
ray_search.fit(X_train, y_train)  # 2.49 minutes

Full example usage

from sklearn.linear_model import SGDClassifier
from scipy.stats import uniform, loguniform
from sklearn.datasets import make_classification
model = SGDClassifier()
params = {"alpha": loguniform(1e-5, 1e-3), "l1_ratio": uniform(0, 1)}
X, y = make_classification()

from sklearn.model_selection import RandomizedSearchCV
search = RandomizedSearchCV(model, params, ...)
search.fit(X, y)

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

from tune_sklearn import TuneSearchCV
search = TuneSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

Footnotes

  1. Their implementation of Hyperband in HpBandSter is included in Auto-PyTorch and BOAH.

  2. See Figures 4, 7 and 8 in “Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization.”

  3. See Figure 1 of the BOHB paper and a paper from an augmented reality company.

  4. SMAC is described in “Sequential Model-Based Optimization for General Algorithm Configuration,” and is available in AutoML.

  5. Spearmint is described in “Practical Bayesian Optimization of Machine Learning Algorithms,” and is available in HIPS/spearmint.

  6. TPE is described in Section 4 of “Algorithms for Hyperparameter Optimization,” and is available through Hyperopt.

  7. From Ray’s README.md: “If the estimator does not support partial_fit, a warning will be shown saying early stopping cannot be done and it will simply run the cross-validation on Ray’s parallel back-end.”

  9. Despite a relevant implementation in dask-ml#527.

  10. Because priority is meaningless if there are an infinite number of workers.

