Source code for kwcoco.util.util_sklearn

"""
Extensions to sklearn constructs
"""
import warnings
import numpy as np
from sklearn.utils.validation import check_array
from sklearn.model_selection._split import (_BaseKFold,)
from packaging.version import parse as Version

ARGSORT_HAS_STABLE_KIND = Version(np.__version__) >= Version('1.15.0')


class StratifiedGroupKFold(_BaseKFold):
    """
    Stratified K-Folds cross-validator with Grouping

    Provides train/test indices to split data in train/test sets.

    This cross-validation object is a variation of GroupKFold that returns
    stratified folds. The folds are made by preserving the percentage of
    samples for each class.

    This is an old interface and should likely be refactored and modernized.

    Parameters
    ----------
    n_splits : int, default=3
        Number of folds. Must be at least 2.
    """

    def __init__(self, n_splits=3, shuffle=False, random_state=None):
        if not shuffle:
            random_state = None
        super(StratifiedGroupKFold, self).__init__(
            n_splits=n_splits, shuffle=shuffle, random_state=random_state)
    def _make_test_folds(self, X, y=None, groups=None):
        """
        Args:
            X (ndarray): data
            y (ndarray): labels
            groups (ndarray): groupids for items. Items with the same groupid
                must be placed in the same group.

        Returns:
            ndarray: test_folds

        Example:
            >>> from kwcoco.util.util_sklearn import *  # NOQA
            >>> import kwarray
            >>> rng = kwarray.ensure_rng(0)
            >>> groups = [1, 1, 3, 4, 2, 2, 7, 8, 8]
            >>> y      = [1, 1, 1, 1, 2, 2, 2, 3, 3]
            >>> X = np.empty((len(y), 0))
            >>> self = StratifiedGroupKFold(random_state=rng, shuffle=True)
            >>> skf_list = list(self.split(X=X, y=y, groups=groups))
            >>> import ubelt as ub
            >>> print(ub.urepr(skf_list, nl=1, with_dtype=False))
            [
                (np.array([2, 3, 4, 5, 6]), np.array([0, 1, 7, 8])),
                (np.array([0, 1, 2, 7, 8]), np.array([3, 4, 5, 6])),
                (np.array([0, 1, 3, 4, 5, 6, 7, 8]), np.array([2])),
            ]
        """
        import kwarray
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', 'invalid value')
            n_splits = self.n_splits
            y = np.asarray(y)
            n_samples = y.shape[0]

            unique_y, y_inversed = np.unique(y, return_inverse=True)
            n_classes = max(unique_y) + 1
            unique_groups, group_idxs = kwarray.group_indices(groups)
            if ARGSORT_HAS_STABLE_KIND:
                # Fixed in latest kwarray, but this doesn't hurt.
                for gxs in group_idxs:
                    gxs.sort()
            grouped_y = kwarray.apply_grouping(y, group_idxs)
            grouped_y_counts = np.array([
                np.bincount(y_, minlength=n_classes) for y_ in grouped_y])
            target_freq = grouped_y_counts.sum(axis=0)
            target_freq = target_freq.astype(float)
            target_ratio = target_freq / float(target_freq.sum())

            # Greedily choose the split assignment that minimizes the local
            # * squared differences in target from actual frequencies
            # * and best equalizes the number of items per fold
            # Distribute groups with most members first
            split_freq = np.zeros((n_splits, n_classes))
            # split_ratios = split_freq / split_freq.sum(axis=1)
            split_ratios = np.ones(split_freq.shape) / split_freq.shape[1]
            split_diffs = ((split_freq - target_ratio) ** 2).sum(axis=1)
            rowsum = grouped_y_counts.sum(axis=1)
            if ARGSORT_HAS_STABLE_KIND:
                sortx = np.argsort(rowsum, kind='stable')[::-1]
            else:
                sortx = np.argsort(rowsum)[::-1]
            grouped_splitx = []

            # import ubelt as ub
            # print(ub.urepr(grouped_y_counts, nl=-1))
            # print('target_ratio = {!r}'.format(target_ratio))

            for count, group_idx in enumerate(sortx):
                # print('---------\n')
                group_freq = grouped_y_counts[group_idx]
                cand_freq = (split_freq + group_freq)
                cand_freq = cand_freq.astype(float)
                cand_ratio = cand_freq / cand_freq.sum(axis=1)[:, None]
                cand_diffs = ((cand_ratio - target_ratio) ** 2).sum(axis=1)
                # Compute loss
                losses = []
                # others = np.nan_to_num(split_diffs)
                other_diffs = np.array([
                    sum(split_diffs[x + 1:]) + sum(split_diffs[:x])
                    for x in range(n_splits)
                ])
                # penalize unbalanced splits
                ratio_loss = other_diffs + cand_diffs
                # penalize heavy splits
                freq_loss = split_freq.sum(axis=1)
                freq_loss = freq_loss.astype(float)
                freq_loss = freq_loss / freq_loss.sum()
                losses = ratio_loss + freq_loss
                # -------
                splitx = np.argmin(losses)
                # print('losses = %r, splitx=%r' % (losses, splitx))
                split_freq[splitx] = cand_freq[splitx]
                split_ratios[splitx] = cand_ratio[splitx]
                split_diffs[splitx] = cand_diffs[splitx]
                grouped_splitx.append(splitx)

            test_folds = np.empty(n_samples, dtype=int)
            for group_idx, splitx in zip(sortx, grouped_splitx):
                idxs = group_idxs[group_idx]
                test_folds[idxs] = splitx

        return test_folds
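
    # Note (added summary, not from the original source): the greedy loop in
    # ``_make_test_folds`` scores each candidate fold ``s`` as
    #
    #     losses[s] = sum(split_diffs[t] for t != s) + cand_diffs[s] + freq_loss[s]
    #
    # i.e. the other folds' current squared class-ratio deviations plus fold
    # ``s``'s squared deviation after hypothetically adding the group, plus
    # fold ``s``'s current share of all assigned samples (which discourages
    # piling groups onto an already heavy fold). Each group is assigned to the
    # fold with the minimal loss.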
    def _iter_test_masks(self, X, y=None, groups=None):
        test_folds = self._make_test_folds(X, y, groups)
        for i in range(self.n_splits):
            yield test_folds == i
    def split(self, X, y, groups=None):
        """
        Generate indices to split data into training and test set.
        """
        y = check_array(y, ensure_2d=False, dtype=None)
        return super(StratifiedGroupKFold, self).split(X, y, groups)
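

# --- Illustrative usage (not part of the original module) ---
# A minimal sketch mirroring the doctest in ``_make_test_folds``: it assumes
# ``kwarray`` is available and prints the train/test indices of each fold.
# Within a fold, no group id appears on both sides of the train/test boundary,
# while class proportions are approximately preserved.
if __name__ == '__main__':
    import kwarray
    rng = kwarray.ensure_rng(0)
    groups = [1, 1, 3, 4, 2, 2, 7, 8, 8]
    y = [1, 1, 1, 1, 2, 2, 2, 3, 3]
    X = np.empty((len(y), 0))
    splitter = StratifiedGroupKFold(n_splits=3, shuffle=True, random_state=rng)
    for fold, (train_idx, test_idx) in enumerate(splitter.split(X, y, groups)):
        print('fold {}: train={}, test={}'.format(fold, train_idx, test_idx))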