Comparing Dask-ML and Ray Tune's Model Selection Algorithms

Modern hyperparameter optimization, Scikit-Learn support, framework support and scaling to many machines.
By Scott Sievert (University of Wisconsin–Madison)
Hyperparameter optimization is the process of deducing model parameters that can’t be learned from data. This process is often time- and resource-consuming, especially in the context of deep learning. A good description of this process can be found in “Tuning the hyper-parameters of an estimator,” and the issues that arise are concisely summarized in Dask-ML’s documentation on “Hyper Parameter Searches.”
There’s a host of libraries and frameworks out there to address this problem. Scikit-Learn’s model selection module has been mirrored in Dask-ML and auto-sklearn, both of which offer advanced hyperparameter optimization techniques. Other implementations that don’t follow the Scikit-Learn interface include Ray Tune, AutoML and Optuna.
Ray recently provided tune-sklearn (docs, source), a wrapper around Ray Tune that mirrors the Scikit-Learn API. The introduction of this library states the following:
Cutting edge hyperparameter tuning techniques (Bayesian optimization, early stopping, distributed execution) can provide significant speedups over grid search and random search.

However, the machine learning ecosystem is missing a solution that provides users with the ability to leverage these new algorithms while allowing users to stay within the Scikit-Learn API. In this blog post, we introduce tune-sklearn [Ray’s tuning library] to bridge this gap. Tune-sklearn is a drop-in replacement for Scikit-Learn’s model selection module with state-of-the-art optimization features.
This claim is inaccurate: for over a year, Dask-ML has provided access to “cutting edge hyperparameter tuning techniques” with a Scikit-Learn-compatible API. To correct their statement, let’s look at each of the features that Ray’s tune-sklearn provides, and compare them to Dask-ML:
Here’s what [Ray’s] tune-sklearn has to offer:

- Consistency with Scikit-Learn API …
- Modern hyperparameter tuning techniques …
- Framework support …
- Scale up … [to] multiple cores and even multiple machines.

[Ray’s] tune-sklearn is also fast.
Dask-ML’s model selection module has every one of these features:

- Consistency with Scikit-Learn API: Dask-ML’s model selection API mirrors the Scikit-Learn model selection API.
- Modern hyperparameter tuning techniques: Dask-ML offers state-of-the-art hyperparameter tuning techniques.
- Framework support: Dask-ML model selection supports many libraries, including Scikit-Learn, PyTorch, Keras, LightGBM and XGBoost.
- Scale up: Dask-ML supports distributed tuning (how could it not?) and larger-than-memory datasets.
Dask-ML is also fast. In “Speed” we show a benchmark between Dask-ML, Ray and Scikit-Learn:

Only time-to-solution is relevant; all of these methods produce similar model scores. See “Speed” for details.
Now, let’s walk through the details of how to use Dask-ML to obtain the 5 features above.
Consistency with the Scikit-Learn API

Dask-ML is consistent with the Scikit-Learn API.

Here’s how to run a hyperparameter optimization with Scikit-Learn, Dask-ML and Ray’s tune-sklearn:
```python
## Trimmed example; see appendix for more detail
from sklearn.model_selection import RandomizedSearchCV
search = RandomizedSearchCV(model, params, ...)
search.fit(X, y)

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

from tune_sklearn import TuneSearchCV
search = TuneSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])
```
The definitions of `model` and `params` follow the normal Scikit-Learn definitions, as detailed in the appendix.
Clearly, both Dask-ML and Ray’s tune-sklearn are Scikit-Learn compatible. Now let’s focus on how each search performs and how it’s configured.
Modern hyperparameter tuning techniques

Dask-ML offers state-of-the-art hyperparameter tuning techniques in a Scikit-Learn interface.

The introduction of Ray’s tune-sklearn made this claim:

tune-sklearn is the only Scikit-Learn interface that allows you to easily leverage Bayesian Optimization, HyperBand and other optimization techniques by simply toggling a few parameters.
The state-of-the-art in hyperparameter optimization is currently “Hyperband.” Hyperband reduces the amount of computation required with a principled early stopping scheme; past that, it’s the same as Scikit-Learn’s popular `RandomizedSearchCV`.
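At its core, Hyperband’s early stopping is repeated “successive halving”: train every candidate a little, keep the best fraction, and stop the rest early. Here’s a minimal pure-Python sketch of one bracket; the `train_one` callback and the `eta` aggressiveness factor are illustrative assumptions, not the internals of either library:

```python
import math

def successive_halving(candidates, train_one, n_iters=100, eta=3):
    """One Hyperband-style bracket: train, keep the top 1/eta, repeat.

    candidates: dict mapping candidate id -> candidate state
    train_one: callback(state, n_iter) -> validation score after
               n_iter more iterations of training
    eta: aggressiveness; each round keeps the top 1/eta candidates
    """
    n_rounds = max(1, round(math.log(len(candidates), eta)))
    budget = max(1, n_iters // n_rounds)
    scores = {}
    while len(candidates) > 1:
        # Train every surviving candidate for a small additional budget.
        scores = {k: train_one(state, budget) for k, state in candidates.items()}
        # Keep only the top 1/eta by score; the rest are stopped early.
        n_keep = max(1, len(candidates) // eta)
        best = sorted(scores, key=scores.get, reverse=True)[:n_keep]
        candidates = {k: candidates[k] for k in best}
    return candidates, scores
```

Compared with a plain random search, most candidates only ever receive a small fraction of the total training budget, which is where the speedup comes from.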
Hyperband works. As such, it’s very popular. Since its introduction in 2016 by Li et al., the Hyperband paper has been cited over 470 times, and the algorithm has been implemented in many different libraries including Dask-ML, Ray Tune, Keras Tuner, Optuna, AutoML,^{1} and Microsoft’s NNI. The original paper shows a rather drastic improvement over all the relevant implementations,^{2} and this drastic improvement persists in follow-up works.^{3} Some illustrative results from Hyperband are below:
^{All algorithms are configured to do the same amount of work except “random 2x,” which does twice as much work. “hyperband (finite)” is similar to Dask-ML’s default implementation, and “bracket s=4” is similar to Ray’s default implementation. “random” is a random search. SMAC,^{4} spearmint,^{5} and TPE^{6} are popular Bayesian algorithms.}
Hyperband is undoubtedly a “cutting edge” hyperparameter optimization technique. Dask-ML and Ray offer Scikit-Learn-compatible implementations of this algorithm with similar internals, and Dask-ML’s implementation also has a rule of thumb for configuration. Both Dask-ML’s and Ray’s documentation encourage the use of Hyperband.
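That rule of thumb ties Hyperband’s inputs to two quantities the user already has estimates for: how many passes through the data the best model needs (`n_epochs`) and how many parameters to sample (`n_params`). A sketch of the arithmetic, paraphrased from Dask-ML’s `HyperbandSearchCV` documentation (the numbers below are illustrative):

```python
def hyperband_inputs(dataset_size, n_epochs, n_params):
    """Rule-of-thumb configuration for a Hyperband search.

    dataset_size: number of examples in the training set
    n_epochs: passes through the data the best model needs to converge
    n_params: how many hyperparameter combinations to sample
    """
    # Total number of examples the best model should see during training.
    n_examples = n_epochs * dataset_size
    # Hyperband's two knobs: max_iter, and the chunk size of the Dask array.
    max_iter = n_params
    chunk_size = n_examples // n_params
    return max_iter, chunk_size

# e.g. a 60,000-example dataset, 100 epochs, 100 sampled parameters:
max_iter, chunk_size = hyperband_inputs(60_000, n_epochs=100, n_params=100)
```

With these inputs, no knowledge of the search internals is required; the user only supplies quantities they would have needed for a plain random search anyway.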
Ray does support running its Hyperband implementation on top of a technique called Bayesian sampling, which changes the hyperparameter sampling scheme used for model initialization and can be used in conjunction with Hyperband’s early stopping scheme. Adding this option to Dask-ML’s Hyperband implementation is future work for Dask-ML.
Framework support

Dask-ML model selection supports many libraries, including Scikit-Learn, PyTorch, Keras, LightGBM and XGBoost.

Ray’s tune-sklearn supports these frameworks:

tune-sklearn is used primarily for tuning Scikit-Learn models, but it also supports and provides examples for many other frameworks with Scikit-Learn wrappers such as Skorch (Pytorch), KerasClassifiers (Keras), and XGBoostClassifiers (XGBoost).
Clearly, both Dask-ML and Ray support many of the same libraries.
However, both Dask-ML and Ray have some qualifications. Certain libraries don’t offer an implementation of `partial_fit`,^{7} so not all of the modern hyperparameter optimization techniques can be offered. Here’s a table comparing different libraries and their support in Dask-ML’s model selection and Ray’s tune-sklearn:
| Model Library | Dask-ML support | Ray support | Dask-ML: early stopping? | Ray: early stopping? |
|---|---|---|---|---|
| Scikit-Learn | ✔ | ✔ | ✔* | ✔* |
| PyTorch (via Skorch) | ✔ | ✔ | ✔ | ✔ |
| Keras (via SciKeras) | ✔ | ✔ | ✔** | ✔** |
| LightGBM | ✔ | ✔ | ❌ | ❌ |
| XGBoost | ✔ | ✔ | ❌ | ❌ |
^{* Only for the models that implement partial_fit.}
^{** Thanks to work by the Dask developers around scikeras#24.}
By this measure, Dask-ML and Ray model selection have the same level of framework support. Of course, Dask has tangential integration with LightGBM and XGBoost through Dask-ML’s `xgboost` module and dask-lightgbm.
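The framework support above is duck-typed: a search can early-stop any estimator that implements the Scikit-Learn API, `partial_fit` included. Here’s a minimal illustrative estimator (a toy written for this post, not code from either library) showing the methods such a search relies on:

```python
import numpy as np

class TinyClassifier:
    """A minimal Scikit-Learn-style estimator that supports early stopping."""

    def __init__(self, lr=0.1):
        self.lr = lr
        self.coef_ = None

    def get_params(self, deep=True):
        return {"lr": self.lr}

    def set_params(self, **params):
        for key, value in params.items():
            setattr(self, key, value)
        return self

    def partial_fit(self, X, y, classes=None):
        # One incremental pass over the data; a real model would do a
        # proper gradient step here. This is what enables early stopping.
        X, y = np.asarray(X, dtype=float), np.asarray(y, dtype=float)
        if self.coef_ is None:
            self.coef_ = np.zeros(X.shape[1])
        preds = X @ self.coef_
        self.coef_ += self.lr * X.T @ (y - preds) / len(X)
        return self

    def score(self, X, y):
        preds = (np.asarray(X, dtype=float) @ self.coef_ > 0.5).astype(int)
        return float((preds == np.asarray(y)).mean())

# Searches decide whether early stopping is possible by checking for this:
assert hasattr(TinyClassifier(), "partial_fit")
```

Wrappers like Skorch and SciKeras earn their spot in the table above precisely by exposing this interface on top of PyTorch and Keras models.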
Scale up

Dask-ML supports distributed tuning (how could it not?), a.k.a. parallelization across multiple machines/cores. In addition, it also supports larger-than-memory data.

[Ray’s] tune-sklearn leverages Ray Tune, a library for distributed hyperparameter tuning, to efficiently and transparently parallelize cross validation on multiple cores and even multiple machines.

Naturally, Dask-ML also scales to multiple cores/machines because it relies on Dask. Dask has wide support for different deployment options that span from your personal machine to supercomputers. Dask will very likely work on top of any computing system you have available, including Kubernetes, SLURM, YARN and Hadoop clusters, as well as your personal machine.

Dask-ML’s model selection also scales to larger-than-memory datasets, and is thoroughly tested. Support for larger-than-memory data is untested in Ray, and there are no examples detailing how to use Ray Tune with the distributed dataset implementations in PyTorch/Keras.
In addition, I have benchmarked Dask-ML’s model selection module to see how the time-to-solution is affected by the number of Dask workers in “Better and faster hyperparameter optimization with Dask.” That is, how does the time to reach a particular accuracy scale with the number of workers $P$? At first it’ll scale like $1/P$, but with a large number of workers the serial portion will dictate the time to solution according to Amdahl’s Law. Briefly, I found Dask-ML’s `HyperbandSearchCV` speedup started to saturate around 24 workers for a particular search.
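Amdahl’s Law makes that saturation concrete: if a fraction f of the search parallelizes perfectly, P workers give a speedup of at most 1 / ((1 − f) + f/P). A quick illustration — the 95% parallel fraction below is a made-up number for the example, not a measurement from this benchmark:

```python
def amdahl_speedup(P, f):
    """Best-case speedup with P workers when a fraction f of the work
    is perfectly parallel (Amdahl's Law)."""
    return 1.0 / ((1.0 - f) + f / P)

# With 95% of the work parallelizable, the speedup is capped at
# 1 / (1 - 0.95) = 20x no matter how many workers are added:
for P in (1, 8, 24, 1000):
    print(P, round(amdahl_speedup(P, f=0.95), 2))
```

The gap between consecutive worker counts shrinks quickly, which matches the saturation observed around a couple dozen workers.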
Speed

Both Dask-ML and Ray are much faster than Scikit-Learn.
Ray’s tune-sklearn runs some benchmarks in its introduction with the `GridSearchCV` class found in Scikit-Learn and Dask-ML. A fairer benchmark would be to use Dask-ML’s `HyperbandSearchCV`, because it is almost the same as the algorithm in Ray’s tune-sklearn. To be specific, I’m interested in comparing these methods:
- Scikit-Learn’s `RandomizedSearchCV`. This is a popular implementation, one that I’ve bootstrapped myself with a custom model.
- Dask-ML’s `HyperbandSearchCV`. This is an early stopping technique for `RandomizedSearchCV`.
- Ray tune-sklearn’s `TuneSearchCV`. This is a slightly different early stopping technique than `HyperbandSearchCV`’s.
Each search is configured to perform the same task: sample 100 parameters and train for no longer than 100 “epochs,” or passes through the data.^{8} Each estimator is configured as its respective documentation suggests. Each search uses 8 workers with a single cross-validation split, and a `partial_fit` call takes one second with 50,000 examples. The complete setup can be found in the appendix.
Here’s how long each library takes to complete the same search:
Notably, we didn’t improve the Dask-ML codebase for this benchmark; we ran the code as it has stood for the last year.^{9} Regardless, it’s possible that other artifacts from biased benchmarks crept into this benchmark.

Clearly, Ray and Dask-ML offer similar performance for 8 workers when compared with Scikit-Learn. To Ray’s credit, their implementation is ~15% faster than Dask-ML’s with 8 workers. We suspect that this performance boost comes from the fact that Ray implements an asynchronous variant of Hyperband. We should investigate this difference between Dask and Ray, and how each balances the tradeoff between the number of FLOPs and the time-to-solution. This will vary with the number of workers: the asynchronous variant of Hyperband provides no benefit if used with a single worker.
Dask-ML reaches good scores quickly in serial environments, or when the number of workers is small. Dask-ML prioritizes fitting high-scoring models: if there are 100 models to fit and only 4 workers available, Dask-ML selects the models that have the highest scores. This is most relevant in serial environments;^{10} see “Better and faster hyperparameter optimization with Dask” for benchmarks. This feature is omitted from this benchmark, which only focuses on time-to-solution.
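That prioritization amounts to a simple rule: when there are more partially trained models than workers, advance the highest-scoring models first. A toy sketch of that selection — illustrative only, not Dask-ML’s actual scheduler:

```python
import heapq

def next_to_train(current_scores, n_workers):
    """Pick which partially trained models to advance next.

    current_scores: dict of model id -> most recent validation score
    Returns the ids of the n_workers highest-scoring models.
    """
    return heapq.nlargest(n_workers, current_scores, key=current_scores.get)

# 6 partially trained models but only 4 workers: best scores go first.
scores = {"a": 0.61, "b": 0.55, "c": 0.72, "d": 0.40, "e": 0.68, "f": 0.59}
print(next_to_train(scores, 4))  # ['c', 'e', 'a', 'f']
```

With unlimited workers every model would advance simultaneously, which is why this only matters when workers are scarce.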
Conclusion
Dask-ML and Ray offer the same features for model selection: state-of-the-art features with a Scikit-Learn-compatible API, and both implementations have fairly wide support for different frameworks and rely on backends that can scale to many machines.
In addition, the Ray implementation has provided motivation for further development, specifically on the following items:

- Adding support for more libraries, including Keras (dask-ml#696, dask-ml#713, scikeras#24). SciKeras is a Scikit-Learn wrapper for Keras that (now) works with Dask-ML model selection because SciKeras models implement the Scikit-Learn model API.
- Better documenting the models that Dask-ML supports (dask-ml#699). Dask-ML supports any model that implements the Scikit-Learn interface, and there are wrappers for Keras, PyTorch, LightGBM and XGBoost. Now, Dask-ML’s documentation prominently highlights this fact.
The Ray implementation has also helped motivate and clarify future work. Dask-ML should include the following implementations:

- A Bayesian sampling scheme for the Hyperband implementation that’s similar to Ray’s and BOHB’s (dask-ml#697).
- A configuration of `HyperbandSearchCV` that’s well-suited for exploratory hyperparameter searches. An initial implementation is in dask-ml#532, which should be benchmarked against Ray.

Luckily, all of these pieces of development are straightforward modifications because the Dask-ML model selection framework is pretty flexible.
Thank you Tom Augspurger, Matthew Rocklin, Julia Signell, and Benjamin Zaitlen for your feedback, suggestions and edits.
Appendix
Benchmark setup
This is the complete setup for the benchmark between Dask-ML, Scikit-Learn and Ray. Complete details can be found at stsievert/dask-hyperband-comparison.
Let’s create a dummy model that takes 1 second for a `partial_fit` call with 50,000 examples. This is appropriate for this benchmark; we’re only interested in the time required to finish the search, not how well the models do. Scikit-Learn, Ray and Dask-ML have very similar methods of choosing hyperparameters to evaluate; they differ in their early stopping techniques.
```python
from scipy.stats import uniform
from sklearn.datasets import make_classification
from benchmark import ConstantFunction  # custom module

# This model sleeps for `latency * len(X)` seconds before
# reporting a score of `value`.
model = ConstantFunction(latency=1 / 50e3, max_iter=max_iter)
params = {"value": uniform(0, 1)}
# This dummy dataset mirrors the MNIST dataset
X_train, y_train = make_classification(n_samples=int(60e3), n_features=784)
```
This model will take 2 minutes to train for 100 epochs (a.k.a. passes through the data). Details can be found at stsievert/dask-hyperband-comparison.
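The 2-minute figure follows directly from the dummy model’s latency:

```python
latency = 1 / 50e3             # seconds per example in one partial_fit pass
n_examples, n_epochs = 60_000, 100
# Seconds to train one model to completion: 120 seconds, i.e. 2 minutes.
train_time = latency * n_examples * n_epochs
print(train_time / 60)
```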
Let’s configure our searches to use 8 workers with a single cross-validation split:
```python
from sklearn.model_selection import RandomizedSearchCV, ShuffleSplit

split = ShuffleSplit(test_size=0.2, n_splits=1)
kwargs = dict(cv=split, refit=False)

search = RandomizedSearchCV(model, params, n_jobs=8, n_iter=n_params, **kwargs)
search.fit(X_train, y_train)  # 20.88 minutes

from dask_ml.model_selection import HyperbandSearchCV
dask_search = HyperbandSearchCV(
    model, params, test_size=0.2, max_iter=max_iter, aggressiveness=4
)

from tune_sklearn import TuneSearchCV
ray_search = TuneSearchCV(
    model, params, n_iter=n_params, max_iters=max_iter, early_stopping=True, **kwargs
)

dask_search.fit(X_train, y_train)  # 2.93 minutes
ray_search.fit(X_train, y_train)  # 2.49 minutes
```
Full example usage
```python
from sklearn.linear_model import SGDClassifier
from scipy.stats import uniform, loguniform
from sklearn.datasets import make_classification

model = SGDClassifier()
params = {"alpha": loguniform(1e-5, 1e-3), "l1_ratio": uniform(0, 1)}
X, y = make_classification()

from sklearn.model_selection import RandomizedSearchCV
search = RandomizedSearchCV(model, params, ...)
search.fit(X, y)

from dask_ml.model_selection import HyperbandSearchCV
search = HyperbandSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])

from tune_sklearn import TuneSearchCV
search = TuneSearchCV(model, params, ...)
search.fit(X, y, classes=[0, 1])
```

Their implementation of Hyperband in HpBandSter is included in AutoPyTorch and BOAH. ↩

See Figures 4, 7 and 8 in “Hyperband: A Novel Bandit-Based Approach to Hyperparameter Optimization.” ↩

See Figure 1 of the BOHB paper and a paper from an augmented reality company. ↩

SMAC is described in “Sequential Model-Based Optimization for General Algorithm Configuration,” and is available in AutoML. ↩

Spearmint is described in “Practical Bayesian Optimization of Machine Learning Algorithms,” and is available in HIPS/spearmint. ↩

TPE is described in Section 4 of “Algorithms for Hyper-Parameter Optimization,” and is available through Hyperopt. ↩

From Ray’s README.md: “If the estimator does not support
partial_fit
, a warning will be shown saying early stopping cannot be done and it will simply run the crossvalidation on Ray’s parallel backend.” ↩ 
I chose to benchmark random searches instead of grid searches because random searches produce better results: grid searches require estimating how important each parameter is beforehand. For more detail, see “Random Search for Hyper-Parameter Optimization” by Bergstra and Bengio. ↩

Despite a relevant implementation in dask-ml#527. ↩

Because priority is meaningless if there are an infinite number of workers. ↩