# Source code for kwcoco.util.util_sklearn

"""
Extensions to sklearn constructs
"""
import warnings
import numpy as np
from sklearn.utils.validation import check_array
from sklearn.model_selection._split import (_BaseKFold,)


class StratifiedGroupKFold(_BaseKFold):
    """Stratified K-Folds cross-validator with Grouping

    Provides train/test indices to split data in train/test sets.

    This cross-validation object is a variation of GroupKFold that returns
    stratified folds. The folds are made by preserving the percentage of
    samples for each class, while keeping every sample of a group in the
    same fold.

    Read more in the :ref:`User Guide <cross_validation>`.

    Parameters
    ----------
    n_splits : int, default=3
        Number of folds. Must be at least 2.

    shuffle : bool, default=False
        Passed to the sklearn base class. Note: the greedy fold assignment
        in ``_make_test_folds`` does not read this value, so the resulting
        folds are currently deterministic.

    random_state : int | RandomState | None, default=None
        Passed to the sklearn base class only when ``shuffle`` is True (it
        is replaced by None otherwise). Like ``shuffle``, it is not used by
        the fold assignment.
    """

    def __init__(self, n_splits=3, shuffle=False, random_state=None):
        if not shuffle:
            # Newer sklearn versions reject a random_state when shuffle is
            # False, so drop it in that case.
            random_state = None
        super().__init__(
            n_splits=n_splits, shuffle=shuffle, random_state=random_state)

    def _make_test_folds(self, X, y=None, groups=None):
        """
        Greedily assign each group of samples to one of ``n_splits`` folds.

        Groups are visited from largest to smallest. Each group is placed in
        the fold that minimizes the sum of:

        * the squared differences between every fold's class distribution
          and the overall class distribution (if the group were added to
          that fold), and
        * the fraction of previously assigned samples that fold already
          holds (to keep fold sizes balanced).

        Args:
            X (ndarray): data (unused; kept for the sklearn interface)
            y (ndarray): class labels, one per sample. Any values that
                :func:`numpy.unique` can sort may be used (ints, strings).
            groups (ndarray): groupids for items. Items with the same groupid
                must be placed in the same group.

        Returns:
            ndarray: test_folds - integer array of length ``len(y)`` where
                ``test_folds[i]`` is the fold in which sample ``i`` is part
                of the test set. Empty when ``y`` is empty.

        Raises:
            ValueError: if ``groups`` is None.

        Example:
            >>> import kwarray
            >>> rng = kwarray.ensure_rng(0)
            >>> groups = [1, 1, 3, 4, 2, 2, 7, 8, 8]
            >>> y = [1, 1, 1, 1, 2, 2, 2, 3, 3]
            >>> X = np.empty((len(y), 0))
            >>> self = StratifiedGroupKFold(random_state=rng, shuffle=True)
            >>> skf_list = list(self.split(X=X, y=y, groups=groups))
            ...
            >>> import ubelt as ub
            >>> print(ub.repr2(skf_list, nl=1, with_dtype=False))
            [
                (np.array([2, 3, 4, 5, 6]), np.array([0, 1, 7, 8])),
                (np.array([0, 1, 2, 7, 8]), np.array([3, 4, 5, 6])),
                (np.array([0, 1, 3, 4, 5, 6, 7, 8]), np.array([2])),
            ]
            >>> # Arbitrary label values (e.g. strings) give the same folds
            >>> y_str = ['a', 'a', 'a', 'a', 'b', 'b', 'b', 'c', 'c']
            >>> folds1 = self._make_test_folds(X, y, groups)
            >>> folds2 = self._make_test_folds(X, y_str, groups)
            >>> assert np.all(folds1 == folds2)
        """
        import kwarray
        if groups is None:
            raise ValueError("The 'groups' parameter should not be None.")

        n_splits = self.n_splits
        y = np.asarray(y)
        n_samples = y.shape[0]
        if n_samples == 0:
            return np.empty(0, dtype=int)

        # Encode labels as 0..n_classes-1 so bincount works for any label
        # values (strings, negative or sparse integers).
        unique_y, y_encoded = np.unique(y, return_inverse=True)
        n_classes = len(unique_y)

        _, group_idxs = kwarray.group_indices(groups)
        grouped_y = kwarray.apply_grouping(y_encoded, group_idxs)
        # grouped_y_counts[g, c] = number of samples in group g with class c
        grouped_y_counts = np.array([
            np.bincount(y_, minlength=n_classes) for y_ in grouped_y])

        target_freq = grouped_y_counts.sum(axis=0).astype(float)
        target_ratio = target_freq / target_freq.sum()

        # Class counts of the groups assigned to each fold so far.
        split_freq = np.zeros((n_splits, n_classes))
        # Squared deviation of each fold from the target class distribution.
        # Empty folds are scored with an all-zero distribution.
        split_diffs = ((split_freq - target_ratio) ** 2).sum(axis=1)

        test_folds = np.empty(n_samples, dtype=int)

        # Distribute groups with most members first
        sortx = np.argsort(grouped_y_counts.sum(axis=1))[::-1]
        for group_idx in sortx:
            group_freq = grouped_y_counts[group_idx]
            # Class counts / distribution each fold would have if this group
            # were added to it. Row sums are >= 1 because every group has at
            # least one member, so this division is well defined.
            cand_freq = (split_freq + group_freq).astype(float)
            cand_ratio = cand_freq / cand_freq.sum(axis=1)[:, None]
            cand_diffs = ((cand_ratio - target_ratio) ** 2).sum(axis=1)

            split_sizes = split_freq.sum(axis=1)
            n_assigned = split_sizes.sum()
            if n_assigned == 0:
                # Nothing assigned yet: every fold is equivalent, so the
                # first (largest) group goes to fold 0.
                splitx = 0
            else:
                # penalize unbalanced splits: deviation of all other folds
                # plus the deviation of the candidate fold
                other_diffs = np.array([
                    sum(split_diffs[x + 1:]) + sum(split_diffs[:x])
                    for x in range(n_splits)
                ])
                ratio_loss = other_diffs + cand_diffs
                # penalize heavy splits: fraction of already-assigned samples
                freq_loss = split_sizes / n_assigned
                splitx = int(np.argmin(ratio_loss + freq_loss))

            split_freq[splitx] = cand_freq[splitx]
            split_diffs[splitx] = cand_diffs[splitx]
            test_folds[group_idxs[group_idx]] = splitx

        return test_folds

    def _iter_test_masks(self, X, y=None, groups=None):
        """Yield a boolean test-set mask for each of the ``n_splits`` folds."""
        test_folds = self._make_test_folds(X, y, groups)
        for i in range(self.n_splits):
            yield test_folds == i

    def split(self, X, y, groups=None):
        """Generate indices to split data into training and test set.

        Args:
            X (array-like): data, one row per sample
            y (array-like): class labels, one per sample
            groups (array-like): group ids, one per sample (required)

        Returns:
            Iterator[Tuple[ndarray, ndarray]]: (train, test) index arrays,
                one pair per fold.
        """
        y = check_array(y, ensure_2d=False, dtype=None)
        return super().split(X, y, groups)
if __name__ == '__main__':
    # Run every doctest in this module via xdoctest.
    # CommandLine:
    #     python ~/code/kwcoco/kwcoco/util/util_sklearn.py all
    from xdoctest import doctest_module
    doctest_module(__file__)