Source code for kwcoco.cli.coco_split

#!/usr/bin/env python
import ubelt as ub
import scriptconfig as scfg


class CocoSplitCLI(object):
    """
    Splits a coco file into two parts based on some criteria.

    Useful for generating quick and dirty train/test splits, but in general
    users should opt for using ``kwcoco subset`` instead to explicitly
    construct these splits based on domain knowledge.
    """
    __command__ = name = 'split'

    class CLIConfig(scfg.Config):
        """
        Split a single COCO dataset into two sub-datasets.
        """
        __default__ = {
            'src': scfg.Value(None, help='input dataset to split', position=1),

            'dst1': scfg.Value('split1.kwcoco.json', help='output path of the larger split'),

            'dst2': scfg.Value('split2.kwcoco.json', help='output path of the smaller split'),

            'factor': scfg.Value(3, help='number of items in dset1 for each item in dset2. Also defines the maximum number of splits that could be written.'),

            'rng': scfg.Value(None, help='A random seed for reproducible splits'),

            'balance_categories': scfg.Value(True, help='if True tries to balance annotation categories across splits'),

            'num_write': scfg.Value(1, isflag=True, help=ub.paragraph(
                '''
                The number of splits to write. Can be between 1 and
                ``factor``. In the case that ``num_write > 1``, then dst1
                and dst2 datasets must contain a {} format string specifier
                so each of the output filenames can be indexed.
                ''')),

            'splitter': scfg.Value(
                'auto', help=ub.paragraph(
                    '''
                    Split method to use.
                    Using "image" will randomly assign each image to a partition.
                    Using "video" will randomly assign each video to a partition.
                    Using "auto" chooses "video" if there are any, otherwise "image".
                    '''),
                choices=['auto', 'image', 'video']),

            'compress': scfg.Value('auto', help='if True writes results with compression'),
        }
        __epilog__ = """
        Example Usage:
            kwcoco split --src special:shapes8 --dst1=learn.kwcoco.json --dst2=test.kwcoco.json --factor=3 --rng=42

            kwcoco split --src special:shapes8 --dst1="train_{:03d}.kwcoco.json" --dst2="vali_{:03d}.kwcoco.json" --factor=3 --rng=42
        """

    @classmethod
    def main(cls, cmdline=True, **kw):
        """
        Read ``src``, partition its images into folds, and write up to
        ``num_write`` pairs of subsets to ``dst1`` / ``dst2``.

        Images are grouped either by their own id (``splitter='image'``) or
        by their ``video_id`` (``splitter='video'``); all members of a group
        land in the same side of a split.

        Args:
            cmdline (bool): forwarded to :class:`CLIConfig`. The doctest
                below passes False and supplies every option via ``kw``.
            **kw: overrides for the keys in ``CLIConfig.__default__``.

        Raises:
            ValueError: if ``src`` is not given, or if ``num_write > 1`` and
                ``dst1`` / ``dst2`` lack a ``{}`` placeholder.
            KeyError: if ``splitter`` is not a recognized method.

        Example:
            >>> from kwcoco.cli.coco_split import *  # NOQA
            >>> import ubelt as ub
            >>> dpath = ub.Path.appdir('kwcoco/tests/cli/split').ensuredir()
            >>> kw = {'src': 'special:vidshapes8',
            >>>       'dst1': dpath / 'train.json',
            >>>       'dst2': dpath / 'test.json'}
            >>> cmdline = False
            >>> cls = CocoSplitCLI
            >>> cls.main(cmdline, **kw)
        """
        import kwcoco
        import kwarray
        from kwcoco.util import util_sklearn

        config = cls.CLIConfig(kw, cmdline=cmdline)
        print('config = {}'.format(ub.urepr(dict(config), nl=1)))

        if config['src'] is None:
            raise ValueError('must specify source: {}'.format(config['src']))

        if config['num_write'] > 1:
            # Each written split formats its index into the output path, so
            # both destinations need the brace characters of a placeholder.
            for key in ('dst1', 'dst2'):
                if not set(str(config[key])).issuperset(set('{}')):
                    raise ValueError(
                        'when num_write > 1, dst1 and dst2 must contain a {} format string placeholder')

        print('reading fpath = {!r}'.format(config['src']))
        dset = kwcoco.CocoDataset.coerce(config['src'])

        splitter = config['splitter']
        if splitter == 'auto':
            splitter = 'video' if dset.n_videos > 0 else 'image'

        images = dset.images()
        cids_per_image = images.annots.cids
        gids = images.lookup('id')

        if splitter == 'video':
            group_ids = images.lookup('video_id')
        elif splitter == 'image':
            group_ids = gids
        else:
            raise KeyError(splitter)

        # Flatten to one row per annotation: (group id, image id, category id).
        # Images without annotations still get one row, labelled with a
        # placeholder category id larger than every real one.
        final_group_ids = []
        final_group_gids = []
        final_group_cids = []

        unique_cids = set(ub.flatten(cids_per_image)) | {0}
        distinct_cid = max(unique_cids) + 11
        for group_id, gid, cids in zip(group_ids, gids, cids_per_image):
            if len(cids) == 0:
                final_group_ids.append(group_id)
                final_group_gids.append(gid)
                final_group_cids.append(distinct_cid)
            else:
                final_group_ids.extend([group_id] * len(cids))
                final_group_gids.extend([gid] * len(cids))
                final_group_cids.extend(cids)

        # Balanced category split
        rng = kwarray.ensure_rng(config['rng'])

        # NOTE(review): if ensure_rng never returns None (likely), shuffle is
        # effectively always True -- confirm this is intended.
        shuffle = rng is not None
        factor = config['factor']
        kfold = util_sklearn.StratifiedGroupKFold(n_splits=factor,
                                                  random_state=rng,
                                                  shuffle=shuffle)

        if config['balance_categories']:
            # Stratify on the annotation category ids.
            split_idxs = list(kfold.split(X=final_group_gids,
                                          y=final_group_cids,
                                          groups=final_group_ids))
        else:
            # No category balancing: the image ids themselves are the labels.
            split_idxs = list(kfold.split(X=final_group_gids,
                                          y=final_group_gids,
                                          groups=final_group_ids))

        dumpkw = {
            'newlines': True,
            'compress': config['compress'],
        }

        for split_num, (idxs1, idxs2) in enumerate(split_idxs):
            # Use this iteration's fold. Re-reading split_idxs[0] here would
            # make every written split a copy of the first one.
            print(f'Build split {split_num} / {factor} with ratio {len(idxs1)}:{len(idxs2)}')
            gids1 = sorted(ub.unique(ub.take(final_group_gids, idxs1)))
            gids2 = sorted(ub.unique(ub.take(final_group_gids, idxs2)))

            dset1 = dset.subset(gids1)
            dset2 = dset.subset(gids2)

            print('stats(dset1): ' + ub.urepr(dset1.basic_stats(), nl=0))
            print('stats(dset2): ' + ub.urepr(dset2.basic_stats(), nl=0))

            # Paths without a placeholder pass through ``format`` unchanged.
            dset1.fpath = str(config['dst1']).format(split_num)
            dset2.fpath = str(config['dst2']).format(split_num)

            print(f'Writing dset1({split_num} / {factor}) = {dset1.fpath!r}')
            dset1.dump(**dumpkw)
            print(f'Writing dset2({split_num} / {factor}) = {dset2.fpath!r}')
            dset2.dump(**dumpkw)

            # Checked after writing, so at least one split is always produced.
            if split_num + 1 >= config['num_write']:
                break

# Module-level alias for this file's command class.
# NOTE(review): presumably looked up by the kwcoco CLI dispatcher -- confirm.
_CLI = CocoSplitCLI


if __name__ == '__main__':
    # Executed directly as a script: run the split command, which by default
    # (cmdline=True) takes its options from the command line.
    _CLI.main()