kwcoco.util.util_sklearn
Extensions to sklearn constructs
Module Contents
Classes
StratifiedGroupKFold | Stratified K-Folds cross-validator with Grouping
- class kwcoco.util.util_sklearn.StratifiedGroupKFold(n_splits=3, shuffle=False, random_state=None)
Bases: sklearn.model_selection._split._BaseKFold
Stratified K-Folds cross-validator with Grouping
Provides train/test indices to split data in train/test sets.
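This cross-validation object is a variation of GroupKFold that returns stratified folds. The folds are made by preserving, as closely as the grouping allows, the percentage of samples for each class, while all samples that share a group id are kept in the same fold.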
- Parameters
n_splits (int, default=3) – Number of folds. Must be at least 2.
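To make the idea of group-aware stratification concrete, here is a minimal pure-Python sketch of one way such folds could be built. It is not kwcoco's implementation: the helper name `greedy_group_folds` and the greedy balancing heuristic are assumptions made purely for illustration. Each group is placed whole into the fold that currently holds the fewest samples of that group's classes, so no group is ever split across folds.

```python
from collections import Counter, defaultdict


def greedy_group_folds(y, groups, n_splits=3):
    """Assign each sample a fold id, keeping every group in a single fold.

    Hypothetical helper for illustration only; not kwcoco's algorithm.
    """
    # Class histogram of every group
    group_hist = defaultdict(Counter)
    for label, gid in zip(y, groups):
        group_hist[gid][label] += 1

    # Running class histogram of every fold
    fold_hist = [Counter() for _ in range(n_splits)]
    group_to_fold = {}

    # Place larger groups first (the sort is stable, so ties keep input order)
    ordered = sorted(group_hist, key=lambda g: -sum(group_hist[g].values()))
    for gid in ordered:
        hist = group_hist[gid]

        def cost(fold):
            # Same-class samples already in the fold, weighted by the
            # group's own counts; ties go to the smaller fold, then the
            # lower fold index.
            overlap = sum(fold_hist[fold][c] * n for c, n in hist.items())
            return (overlap, sum(fold_hist[fold].values()), fold)

        best = min(range(n_splits), key=cost)
        fold_hist[best].update(hist)
        group_to_fold[gid] = best

    # Every sample inherits the fold of its group
    return [group_to_fold[g] for g in groups]


groups = [1, 1, 3, 4, 2, 2, 7, 8, 8]
y = [1, 1, 1, 1, 2, 2, 2, 3, 3]
fold_ids = greedy_group_folds(y, groups, n_splits=3)
print(fold_ids)
```

The sketch is deterministic and ignores `shuffle` and `random_state`, so its fold assignment will generally differ from what StratifiedGroupKFold produces; only the constraint it illustrates (one fold per group, classes spread across folds) is the same.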
- _make_test_folds(self, X, y=None, groups=None)
- Parameters
X (ndarray) – data
y (ndarray) – labels
groups (ndarray) – group ids for items. Items with the same group id must be placed in the same fold.
- Returns
test_folds
- Return type
Example
>>> import numpy as np
>>> import kwarray
>>> from kwcoco.util.util_sklearn import StratifiedGroupKFold
>>> rng = kwarray.ensure_rng(0)
>>> # One group id and one class label per sample
>>> groups = [1, 1, 3, 4, 2, 2, 7, 8, 8]
>>> y = [1, 1, 1, 1, 2, 2, 2, 3, 3]
>>> # Only the number of rows in X matters here, so it has no feature columns
>>> X = np.empty((len(y), 0))
>>> self = StratifiedGroupKFold(random_state=rng, shuffle=True)
>>> skf_list = list(self.split(X=X, y=y, groups=groups))
...
>>> import ubelt as ub
>>> print(ub.repr2(skf_list, nl=1, with_dtype=False))
[
    (np.array([2, 3, 4, 5, 6]), np.array([0, 1, 7, 8])),
    (np.array([0, 1, 2, 7, 8]), np.array([3, 4, 5, 6])),
    (np.array([0, 1, 3, 4, 5, 6, 7, 8]), np.array([2])),
]
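The grouping guarantee can be checked directly against the output of the example. The sketch below copies the `groups` list and the three (train, test) index pairs from the example into plain Python lists and confirms that no group id appears on both sides of any split. The helper name `shared_groups` is hypothetical and is not part of kwcoco.

```python
# Data and folds copied from the example above
groups = [1, 1, 3, 4, 2, 2, 7, 8, 8]
skf_list = [
    ([2, 3, 4, 5, 6], [0, 1, 7, 8]),
    ([0, 1, 2, 7, 8], [3, 4, 5, 6]),
    ([0, 1, 3, 4, 5, 6, 7, 8], [2]),
]


def shared_groups(train_idx, test_idx, groups):
    """Return the group ids that occur in both the train and test indices."""
    train_groups = {groups[i] for i in train_idx}
    test_groups = {groups[i] for i in test_idx}
    return train_groups & test_groups


# One (possibly empty) set of leaked group ids per split
leaks = [shared_groups(train, test, groups) for train, test in skf_list]

# Sorted list of every index used as a test sample across all splits
test_coverage = sorted(i for _, test in skf_list for i in test)

assert all(len(leak) == 0 for leak in leaks)
assert test_coverage == list(range(len(groups)))
```

Both assertions pass for the folds shown: each group lands entirely in train or entirely in test, and every sample is used as a test item exactly once across the three splits.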