5.2. Grid Search
Grid Search is used to optimize the parameters of a model (e.g. C, kernel and gamma for a Support Vector Classifier, alpha for Lasso, etc.) using an internal cross-validation scheme.
5.2.1. GridSearchCV
The main class for implementing hyper-parameter grid search in scikit-learn is grid_search.GridSearchCV. This class is passed a base model instance (for example sklearn.svm.SVC()) along with a grid of candidate hyper-parameter values such as:
[{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
{'C': [1, 10, 100, 1000], 'kernel': ['linear']}]
The grid_search.GridSearchCV instance implements the usual estimator API: when “fitting” it on a dataset, all the possible combinations of hyper-parameter values are evaluated and the best combination is retained.
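As a minimal, hedged sketch (assuming the pre-0.18 module layout used in this section, where GridSearchCV lives in sklearn.grid_search; the digits dataset and SVC settings are illustrative):

from sklearn import datasets, svm
from sklearn.grid_search import GridSearchCV

digits = datasets.load_digits()
param_grid = [{'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']},
              {'C': [1, 10, 100, 1000], 'kernel': ['linear']}]

# "Fitting" evaluates every hyper-parameter combination by cross-validation
# and refits the best one on the whole dataset.
clf = GridSearchCV(svm.SVC(), param_grid)
clf.fit(digits.data, digits.target)
print(clf.best_estimator_)  # the model with the retained combination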
Model selection: development and evaluation
Model selection with GridSearchCV can be seen as a way to use the labeled data to “train” the hyper-parameters of the model.
When evaluating the resulting model it is important to do it on held-out samples that were not seen during the grid search process: it is recommended to split the data into a development set (to be fed to the GridSearchCV instance) and an evaluation set to compute performance metrics.
This can be done by using the cross_validation.train_test_split utility function.
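A hedged sketch of that workflow (again assuming the pre-0.18 cross_validation module named above; the split size is an illustrative choice):

from sklearn import datasets, svm
from sklearn.cross_validation import train_test_split
from sklearn.grid_search import GridSearchCV

digits = datasets.load_digits()
# Development set for the grid search, held-out evaluation set for the metric.
X_dev, X_eval, y_dev, y_eval = train_test_split(
    digits.data, digits.target, test_size=0.25, random_state=0)

clf = GridSearchCV(svm.SVC(), {'C': [1, 10, 100], 'gamma': [0.001, 0.0001]})
clf.fit(X_dev, y_dev)  # the grid search never sees the evaluation set
print(clf.best_estimator_.score(X_eval, y_eval))  # unbiased performance estimate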
5.2.2. Examples
- See Parameter estimation using grid search with a nested cross-validation for an example of Grid Search computation on the digits dataset.
- See Sample pipeline for text feature extraction and evaluation for an example of Grid Search coupling parameters from a text document feature extractor (n-gram count vectorizer and TF-IDF transformer) with a classifier (here a linear SVM trained with SGD, using either an elastic net or L2 penalty) using a pipeline.Pipeline instance.
Note
Computations can be run in parallel, if your OS supports it, by setting the keyword argument n_jobs=-1; see the function signature for more details.
5.2.3. Alternatives to brute force grid search
5.2.3.1. Model specific cross-validation
Some models can fit data for a range of values of some parameter almost as efficiently as fitting the estimator for a single value of that parameter. This feature can be leveraged to perform a more efficient cross-validation for the model selection of this parameter.
The most common parameter amenable to this strategy is the parameter encoding the strength of the regularizer. In this case we say that we compute the regularization path of the estimator.
Here is the list of such models (a usage sketch follows the list):
linear_model.RidgeCV([alphas, ...]) | Ridge regression with built-in cross-validation. |
linear_model.RidgeClassifierCV([alphas, ...]) | Ridge classifier with built-in cross-validation. |
linear_model.LarsCV([fit_intercept, ...]) | Cross-validated Least Angle Regression model. |
linear_model.LassoLarsCV([fit_intercept, ...]) | Cross-validated Lasso, using the LARS algorithm. |
linear_model.LassoCV([eps, n_alphas, ...]) | Lasso linear model with iterative fitting along a regularization path. |
linear_model.ElasticNetCV([rho, eps, ...]) | Elastic Net model with iterative fitting along a regularization path. |
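For instance, a hedged sketch with linear_model.LassoCV (the synthetic data and settings are illustrative assumptions):

import numpy as np
from sklearn.linear_model import LassoCV

# Illustrative synthetic regression problem: only two informative features.
rng = np.random.RandomState(0)
X = rng.randn(100, 20)
y = X[:, 0] + 0.5 * X[:, 1] + 0.01 * rng.randn(100)

# The full regularization path is fitted on each cross-validation fold and
# the alpha with the best held-out performance is retained.
model = LassoCV(n_alphas=100).fit(X, y)
print(model.alpha_)  # the selected regularization strength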
5.2.3.2. Information Criterion
Some models can offer an information-theoretic closed-form formula of the optimal estimate of the regularization parameter by computing a single regularization path (instead of several when using cross-validation).
Here is the list of models benefiting from the Akaike Information Criterion (AIC) or the Bayesian Information Criterion (BIC) for automated model selection (a usage sketch follows):
linear_model.LassoLarsIC([criterion, ...]) | Lasso model fit with Lars using BIC or AIC for model selection |
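A hedged sketch on the same kind of synthetic data as above (the choice of BIC here is illustrative):

import numpy as np
from sklearn.linear_model import LassoLarsIC

rng = np.random.RandomState(0)
X = rng.randn(100, 20)
y = X[:, 0] + 0.5 * X[:, 1] + 0.01 * rng.randn(100)

# A single Lars path is computed and alpha is chosen by minimizing BIC,
# with no cross-validation loop.
model = LassoLarsIC(criterion='bic').fit(X, y)
print(model.alpha_)  # alpha selected by the information criterion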
5.2.3.3. Out of Bag Estimates
When using ensemble methods based upon bagging, i.e. generating new training sets by sampling with replacement, part of the training set remains unused. For each classifier in the ensemble, a different part of the training set is left out.
This left-out portion can be used to estimate the generalization error without having to rely on a separate validation set. This estimate comes “for free” as no additional data is needed, and it can be used for model selection (see the sketch after the list of classes below).
This is currently implemented in the following classes:
ensemble.RandomForestClassifier([...]) | A random forest classifier. |
ensemble.RandomForestRegressor([...]) | A random forest regressor. |
ensemble.ExtraTreesClassifier([...]) | An extra-trees classifier. |
ensemble.ExtraTreesRegressor([n_estimators, ...]) | An extra-trees regressor. |
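A hedged sketch with a random forest (assuming the oob_score constructor flag, which exposes the resulting estimate as the oob_score_ attribute after fitting):

from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

digits = load_digits()

# With oob_score=True, each training sample is scored using only the trees
# whose bootstrap sample did not include it.
forest = RandomForestClassifier(n_estimators=100, oob_score=True, random_state=0)
forest.fit(digits.data, digits.target)
print(forest.oob_score_)  # out-of-bag estimate of the generalization accuracy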