This documentation is for scikit-learn version 0.11-git

3. Model selection: choosing estimators and their parameters

3.1. Score, and cross-validated scores

As we have seen, every estimator exposes a score method that can judge the quality of the fit (or the prediction) on new data. Bigger is better.

>>> from scikits.learn import datasets, svm
>>> digits = datasets.load_digits()
>>> X_digits = digits.data
>>> y_digits = digits.target
>>> svc = svm.SVC()
>>> svc.fit(X_digits[:-100], y_digits[:-100]).score(X_digits[-100:], y_digits[-100:])
0.41999999999999998
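
For a classifier such as SVC, the value returned by score is the mean accuracy: the fraction of samples whose predicted label equals the true label. A minimal sketch of that computation in plain Python follows; the accuracy helper and the toy labels are illustrative assumptions, not part of the library.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions equal to the true labels (illustrative helper)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matches = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return matches / float(len(y_true))


# Toy labels, made up for illustration: 3 of the 4 predictions are correct.
y_true = [0, 1, 2, 2]
y_pred = [0, 1, 1, 2]
print(accuracy(y_true, y_pred))  # 0.75
```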

To get a better measure of prediction accuracy (which we can use as a proxy for the goodness of fit of the model), we can successively split the data into folds that we use for training and testing:

>>> import numpy as np
>>> X_folds = np.array_split(X_digits, 3)
>>> y_folds = np.array_split(y_digits, 3)
>>> scores = list()
>>> for k in range(3):
...     # We use 'list' to copy, in order to 'pop' later on
...     X_train = list(X_folds)
...     X_test  = X_train.pop(k)
...     X_train = np.concatenate(X_train)
...     y_train = list(y_folds)
...     y_test  = y_train.pop(k)
...     y_train = np.concatenate(y_train)
...     scores.append(svc.fit(X_train, y_train).score(X_test, y_test))
>>> print scores
[0.41068447412353926, 0.41569282136894825, 0.42737896494156929]

This is called K-fold cross-validation.

3.2. Cross-validation generators

The code above to split the data into train and test sets is tedious to write. The scikits.learn package exposes cross-validation generators that generate lists of indices for this purpose:

>>> from scikits.learn import cross_val
>>> k_fold = cross_val.KFold(n=6, k=3, indices=True)
>>> for train_indices, test_indices in k_fold:
...     print 'Train: %s | test: %s' % (train_indices, test_indices)
Train: [2 3 4 5] | test: [0 1]
Train: [0 1 4 5] | test: [2 3]
Train: [0 1 2 3] | test: [4 5]
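
To make the idea of a cross-validation generator concrete, here is a minimal sketch in plain Python of a generator that yields train and test indices for consecutive folds. The name k_fold_indices is hypothetical and this is not the library's implementation; the handling of sizes that do not divide evenly (the first folds get one extra sample) is an assumption of the sketch.

```python
def k_fold_indices(n, k):
    """Yield (train, test) index lists for k consecutive folds over range(n).

    Illustrative helper, not the library's implementation. The first
    n % k folds get one extra sample so that every index is tested once.
    """
    if not 2 <= k <= n:
        raise ValueError("k must satisfy 2 <= k <= n")
    base, extra = divmod(n, k)
    start = 0
    for fold in range(k):
        stop = start + base + (1 if fold < extra else 0)
        test = list(range(start, stop))
        train = list(range(0, start)) + list(range(stop, n))
        yield train, test
        start = stop


for train, test in k_fold_indices(6, 3):
    print("Train: %s | test: %s" % (train, test))
```

For n=6 and k=3 the sketch produces the same three splits as the KFold output above.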

The cross-validation can then be implemented easily:

>>> kfold = cross_val.KFold(len(X_digits), k=3)
>>> [svc.fit(X_digits[train], y_digits[train]).score(X_digits[test], y_digits[test])
...          for train, test in kfold]
[0.41068447412353926, 0.41569282136894825, 0.42737896494156929]

To compute the cross-validated score of an estimator, the scikits.learn package exposes a helper function:

>>> cross_val.cross_val_score(svc, X_digits, y_digits, cv=kfold, n_jobs=-1)
array([ 0.41068447,  0.41569282,  0.42737896])

n_jobs=-1 means that the computation will be dispatched to all the CPUs of the computer.
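
Conceptually, such a helper fits the estimator on each training split and scores it on the matching test split. The sketch below shows that loop in plain Python, run serially. Both cross_val_score_sketch and MajorityClassifier are hypothetical names introduced for illustration, and the data is made up; any object with fit and score methods would work in place of the toy estimator.

```python
from collections import Counter


class MajorityClassifier(object):
    """Toy estimator (hypothetical): always predicts the most common training label."""

    def fit(self, X, y):
        self.label_ = Counter(y).most_common(1)[0][0]
        return self  # returning self allows fit(...).score(...) chaining

    def score(self, X, y):
        # Mean accuracy of the constant prediction on the given labels
        return sum(1 for label in y if label == self.label_) / float(len(y))


def cross_val_score_sketch(estimator, X, y, folds):
    """Fit on each train split and score on the matching test split."""
    scores = []
    for train, test in folds:
        X_train = [X[i] for i in train]
        y_train = [y[i] for i in train]
        X_test = [X[i] for i in test]
        y_test = [y[i] for i in test]
        scores.append(estimator.fit(X_train, y_train).score(X_test, y_test))
    return scores


# Made-up data: six samples, three folds of two consecutive samples each
X = [[0], [1], [2], [3], [4], [5]]
y = [0, 0, 0, 1, 0, 0]
folds = [([2, 3, 4, 5], [0, 1]), ([0, 1, 4, 5], [2, 3]), ([0, 1, 2, 3], [4, 5])]
print(cross_val_score_sketch(MajorityClassifier(), X, y, folds))  # [1.0, 0.5, 1.0]
```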

Cross-validation generators

KFold(n, k): Split into K folds, train on K-1 folds, test on the left-out fold.
StratifiedKFold(y, k): Make sure that all classes are evenly represented across the folds.
LeaveOneOut(n): Leave one observation out.
LeaveOneLabelOut(labels): Takes a label array to group observations.
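
The leave-one-out and leave-one-label-out strategies can be sketched in a few lines of plain Python. The function names below are hypothetical and only illustrate the splitting logic, not the library's implementation. Note that "label" here means a group identifier attached to each observation, not the class to predict.

```python
def leave_one_out(n):
    """Yield (train, test) index lists, holding out one observation at a time."""
    for i in range(n):
        yield [j for j in range(n) if j != i], [i]


def leave_one_label_out(labels):
    """Yield (train, test) index lists, holding out one group label at a time."""
    for held_out in sorted(set(labels)):
        train = [i for i, label in enumerate(labels) if label != held_out]
        test = [i for i, label in enumerate(labels) if label == held_out]
        yield train, test


print(list(leave_one_out(3)))
# Made-up group labels: observations 0-1 belong to group 'a', 2-3 to group 'b'
print(list(leave_one_label_out(['a', 'a', 'b', 'b'])))
```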

Exercise

On the digits dataset, plot the cross-validation score of a SVC estimator with an RBF kernel as a function of gamma (use a logarithmic grid of points, from 1e-6 to 1e-1).

3.3. Grid-search and cross-validated estimators

3.3.2. Cross-validated estimators

Cross-validation to set a parameter can be done more efficiently on an algorithm-by-algorithm basis. This is why, for certain estimators, the scikits.learn package exposes “CV” estimators that set their parameter automatically by cross-validation:

>>> from scikits.learn import linear_model, datasets
>>> lasso = linear_model.LassoCV()
>>> diabetes = datasets.load_diabetes()
>>> X_diabetes = diabetes.data
>>> y_diabetes = diabetes.target
>>> lasso.fit(X_diabetes, y_diabetes)
LassoCV(alphas=array([ 2.14804,  2.00327, ...,  0.0023 ,  0.00215]),
    copy_X=True, cv=None, eps=0.001, fit_intercept=True, max_iter=1000,
    n_alphas=100, normalize=False, precompute='auto', tol=0.0001,
    verbose=False)
>>> # The estimator automatically chose its regularization parameter alpha:
>>> lasso.alpha
0.013180196198701137

These estimators are named like their counterparts, with ‘CV’ appended to their name.
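
The principle behind a “CV” estimator can be sketched naively: try each candidate value of the parameter, measure the mean test error over the folds, and keep the best value. The sketch below does this in plain Python. It swaps in a one-feature ridge regression without intercept, because that model has a simple closed form; it is not the Lasso, and it does not reproduce the more efficient algorithm-specific strategy mentioned above. All function names and the data are assumptions made for illustration.

```python
def fit_ridge_1d(x, y, alpha):
    """Closed-form one-feature ridge fit, no intercept: w = sum(xy) / (sum(x^2) + alpha)."""
    return sum(a * b for a, b in zip(x, y)) / (sum(a * a for a in x) + alpha)


def mean_squared_error(x, y, w):
    """Mean squared error of the prediction w * x against y."""
    return sum((b - w * a) ** 2 for a, b in zip(x, y)) / float(len(x))


def select_alpha_by_cv(x, y, alphas, k=3):
    """Return the alpha with the lowest mean test error over k interleaved folds."""
    n = len(x)
    best_alpha, best_error = None, float("inf")
    for alpha in alphas:
        errors = []
        for fold in range(k):
            # Sample i belongs to the test split of fold i % k
            test = [i for i in range(n) if i % k == fold]
            train = [i for i in range(n) if i % k != fold]
            w = fit_ridge_1d([x[i] for i in train], [y[i] for i in train], alpha)
            errors.append(mean_squared_error([x[i] for i in test], [y[i] for i in test], w))
        mean_error = sum(errors) / float(k)
        if mean_error < best_error:
            best_alpha, best_error = alpha, mean_error
    return best_alpha


# Made-up data following y = 2 * x exactly, so the unregularized fit should win
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
y = [2.0, 4.0, 6.0, 8.0, 10.0, 12.0]
print(select_alpha_by_cv(x, y, alphas=[0.0, 0.1, 1.0, 10.0]))  # 0.0
```

On noisy data a nonzero alpha may win, which is the situation the exercise below explores.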

Exercise

On the diabetes dataset, find the optimal regularization parameter alpha.

Bonus: How much can you trust the selection of alpha?