check_cv

sklearn.model_selection.check_cv(cv=5, y=None, *, classifier=False, shuffle=False, random_state=None)

Input checker utility for building a cross-validator.

Parameters:
cv : int, cross-validation generator, iterable or None, default=5

Determines the cross-validation splitting strategy. Possible inputs for cv are:

- None, to use the default 5-fold cross validation,
- integer, to specify the number of folds,
- CV splitter,
- an iterable that generates (train, test) splits as arrays of indices.

For integer/None inputs, if classifier is True and y is either binary or multiclass, StratifiedKFold is used. In all other cases, KFold is used.

Refer to the User Guide for the various cross-validation strategies that can be used here.

Changed in version 0.22: cv default value changed from 3-fold to 5-fold.

y : array-like, default=None

The target variable for supervised learning problems.

classifier : bool, default=False

Whether the task is a classification task. When True and cv is an integer or None, StratifiedKFold is used if y is binary or multiclass; otherwise KFold is used. Ignored if cv is a cross-validator instance or iterable.

shuffle : bool, default=False

Whether to shuffle the data before splitting it into folds. Note that the samples within each split will not be shuffled. Only applies if cv is an integer or None; if cv is a cross-validation generator or an iterable, shuffle is ignored.

random_state : int, RandomState instance or None, default=None

When shuffle is True and cv is an integer or None, random_state affects the ordering of the indices, which controls the randomness of each fold. Otherwise, this parameter has no effect. Pass an int for reproducible output across multiple function calls. See Glossary.
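The selection rules described for cv, y and classifier can be checked directly. The following is a minimal sketch, not part of the official examples; the toy targets `y_multiclass` and `y_continuous` and the choice of `ShuffleSplit` as the pre-built splitter are assumptions made for illustration.

```python
import numpy as np
from sklearn.model_selection import KFold, ShuffleSplit, StratifiedKFold, check_cv

# Hypothetical toy targets: one multiclass, one continuous.
y_multiclass = np.array([0, 0, 1, 1, 2, 2])
y_continuous = np.array([0.1, 0.5, 1.3, 2.7, 3.1, 4.9])

# Integer cv + classifier=True + multiclass y -> StratifiedKFold.
cv_strat = check_cv(cv=2, y=y_multiclass, classifier=True)

# Integer cv + classifier=True, but y is continuous -> KFold.
cv_cont = check_cv(cv=2, y=y_continuous, classifier=True)

# Integer cv + classifier=False -> KFold, whatever y looks like.
cv_plain = check_cv(cv=2, y=y_multiclass, classifier=False)

# A cross-validator instance is returned as-is; classifier is ignored.
splitter = ShuffleSplit(n_splits=3, random_state=0)
cv_same = check_cv(splitter, y=y_multiclass, classifier=True)

print(type(cv_strat).__name__)  # StratifiedKFold
print(type(cv_cont).__name__)   # KFold
print(type(cv_plain).__name__)  # KFold
print(cv_same is splitter)      # True
```
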

Returns:
checked_cv : a cross-validator instance.

The return value is a cross-validator which generates the train/test splits via the split method.
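As an illustration of how the returned object is typically consumed, the sketch below iterates over the splits produced by its split method. The array `X` is a hypothetical placeholder dataset of six samples.

```python
import numpy as np
from sklearn.model_selection import check_cv

# Hypothetical data: six samples, two features.
X = np.arange(12).reshape(6, 2)

# An integer with the default classifier=False yields a KFold splitter.
cv = check_cv(cv=3)

test_indices = []
for train_idx, test_idx in cv.split(X):
    # Within one split, train and test indices never overlap.
    assert set(train_idx).isdisjoint(test_idx)
    test_indices.extend(test_idx.tolist())

print(cv.get_n_splits())     # 3
print(sorted(test_indices))  # [0, 1, 2, 3, 4, 5]
```

Each sample appears in exactly one test fold, so the collected test indices cover the whole dataset once.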

Examples

>>> from sklearn.model_selection import check_cv
>>> check_cv(cv=5, y=None, classifier=False)
KFold(...)
>>> check_cv(cv=5, y=[1, 1, 0, 0, 0, 0], classifier=True)
StratifiedKFold(...)
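The examples above cover integer inputs only. A further sketch, under the assumption of two hand-written index splits named `custom_splits`, shows the iterable case: the result is a cross-validator object that replays the supplied (train, test) pairs.

```python
import numpy as np
from sklearn.model_selection import check_cv

# Hypothetical hand-written (train, test) index splits over six samples.
custom_splits = [
    (np.array([0, 1, 2]), np.array([3, 4, 5])),
    (np.array([3, 4, 5]), np.array([0, 1, 2])),
]

# The iterable is wrapped in an object exposing split and get_n_splits.
cv = check_cv(custom_splits)

X = np.zeros((6, 1))
replayed = [(train.tolist(), test.tolist()) for train, test in cv.split(X)]

print(cv.get_n_splits())  # 2
print(replayed[0])        # ([0, 1, 2], [3, 4, 5])
```
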