"""
from __future__ import annotations
Downloads and converts CIFAR 10 and CIFAR 100 to kwcoco format
"""
import ubelt as ub
import os
import pickle
import kwimage
[docs]
def _convert_cifar_to_kwcoco(dpath, cifar_dset, cifar_name, classes):
    """
    Convert one CIFAR split into a kwcoco bundle, reusing a cached result.

    Every image in ``cifar_dset`` is written as a PNG under
    ``<dpath>/<cifar_name>/images/<category>/`` and registered in a kwcoco
    dataset together with a single annotation whose box spans the whole
    image and whose category is the image label. The dataset is dumped to
    ``<dpath>/<cifar_name>/<cifar_name>.kwcoco.json``.

    Args:
        dpath (str | PathLike): directory holding the bundle and cache stamp.
        cifar_dset: dataset exposing ``data``, ``targets`` and
            ``class_to_idx`` (the torchvision CIFAR datasets in this file).
        cifar_name (str): split name; used for the bundle directory, the
            kwcoco file stem and the dataset tag.
        classes (List[str]): category names indexed by the integer targets.

    Returns:
        kwcoco.CocoDataset: the freshly converted or reloaded dataset.
    """
    import kwcoco

    bundle_dpath = (ub.Path(dpath) / cifar_name).ensuredir()
    img_dpath = (bundle_dpath / 'images').ensuredir()
    coco_fpath = bundle_dpath / '{}.kwcoco.json'.format(cifar_name)

    # Included in the stamp's ``depends``; changing it is presumably how
    # conversions made by older versions of this code get invalidated.
    CONVERSION_VERSION = 3
    stamp = ub.CacheStamp(
        'convert_cifar',
        dpath=dpath,
        depends=[cifar_name, CONVERSION_VERSION],
        verbose=3,
    )

    # A valid stamp is not enough: the kwcoco file may have been removed
    # since it was written, in which case the conversion has to be redone.
    if stamp.expired() or not coco_fpath.exists():
        coco_dset = kwcoco.CocoDataset(bundle_dpath=bundle_dpath)
        coco_dset.fpath = coco_fpath

        # Category ids mirror the class indices of the source dataset
        for catname in classes:
            cid = cifar_dset.class_to_idx[catname]
            coco_dset.add_category(id=cid, name=catname)

        labeled_images = zip(cifar_dset.data, cifar_dset.targets)
        prog = ub.ProgIter(
            labeled_images,
            total=len(cifar_dset.targets),
            desc=f'convert {cifar_name}',
        )
        for gx, (imdata, cidx) in enumerate(prog):
            catname = classes[cidx]
            name = f'img_{gx:08d}'
            subdir = (img_dpath / catname).ensuredir()
            fpath = subdir / f'{name}.png'
            # Image paths are stored relative to the bundle directory
            fname = fpath.relative_to(bundle_dpath)
            # Images already on disk (e.g. from an earlier run) are reused
            if not fpath.exists():
                kwimage.imwrite(fpath, imdata)
            height, width = imdata.shape[0:2]
            gid = coco_dset.add_image(
                file_name=fname,
                id=gx,
                name=name,
                channels='red|green|blue',
                num_overviews=0,
                width=width,
                height=height,
            )
            cid = coco_dset.index.name_to_cat[catname]['id']
            # One full-image box per image carries the class label
            coco_dset.add_annotation(
                image_id=gid, bbox=[0, 0, width, height], category_id=cid
            )

        print('write coco_dset.fpath = {!r}'.format(coco_dset.fpath))
        # Write the file BEFORE renewing the stamp, so a failed dump can
        # never leave behind a stamp that claims the conversion succeeded.
        coco_dset.dump(coco_dset.fpath, newlines=True)
        stamp.renew()
    else:
        coco_dset = kwcoco.CocoDataset(coco_fpath)

    coco_dset.tag = cifar_name
    return coco_dset
[docs]
def convert_cifar10(dpath=None):
    """
    Download CIFAR-10 through torchvision and convert both splits to kwcoco.

    Args:
        dpath (str | PathLike | None): working directory; raw downloads go
            in its ``download`` subdirectory. Defaults to
            ``ub.Path.appdir('kwcoco/data')``.

    Returns:
        Tuple[kwcoco.CocoDataset, kwcoco.CocoDataset]:
            the train dataset followed by the test dataset.
    """
    import torchvision

    root_dpath = ub.Path.appdir('kwcoco/data') if dpath is None else ub.Path(dpath)
    dpath = root_dpath.ensuredir()

    # The label names are read from the metadata file in the download
    # directory and checked against this hard-coded list.
    expected_classes = [
        'airplane', 'automobile', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ]

    cifar_cls = torchvision.datasets.CIFAR10
    download_dpath = (dpath / 'download').ensuredir()

    # --- train split ---
    train_split = cifar_cls(root=download_dpath, download=True, train=True)
    meta_fpath = ub.Path(train_split.root).joinpath(
        train_split.base_folder, 'batches.meta'
    )
    with open(meta_fpath, 'rb') as meta_file:
        classes = pickle.load(meta_file)['label_names']
    assert classes == expected_classes
    train_coco_dset = _convert_cifar_to_kwcoco(
        dpath, train_split, 'cifar10-train', classes
    )

    # --- test split (reuses the label names read above) ---
    test_split = cifar_cls(root=download_dpath, download=True, train=False)
    test_coco_dset = _convert_cifar_to_kwcoco(
        dpath, test_split, 'cifar10-test', classes
    )

    return train_coco_dset, test_coco_dset
[docs]
def convert_cifar100(dpath=None):
    """
    Download CIFAR-100 through torchvision and convert both splits to kwcoco.

    Args:
        dpath (str | PathLike | None): working directory; raw downloads go
            in its ``download`` subdirectory. Defaults to
            ``ub.Path.appdir('kwcoco/data')``.

    Returns:
        List[kwcoco.CocoDataset]: ``[train_coco_dset, test_coco_dset]``.

    Raises:
        AssertionError: if the fine label names stored with the train split
            differ from the expected list (only while asserts are enabled).
    """
    import torchvision

    # Normalize to an existing ub.Path, the same way convert_cifar10 does
    if dpath is None:
        dpath = ub.Path.appdir('kwcoco/data').ensuredir()
    else:
        dpath = ub.Path(dpath).ensuredir()

    expected_classes = [
        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver',
        'bed', 'bee', 'beetle', 'bicycle', 'bottle',
        'bowl', 'boy', 'bridge', 'bus', 'butterfly',
        'camel', 'can', 'castle', 'caterpillar', 'cattle',
        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach',
        'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
        'dolphin', 'elephant', 'flatfish', 'forest', 'fox',
        'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard',
        'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid',
        'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
        'plain', 'plate', 'poppy', 'porcupine', 'possum',
        'rabbit', 'raccoon', 'ray', 'road', 'rocket',
        'rose', 'sea', 'seal', 'shark', 'shrew',
        'skunk', 'skyscraper', 'snail', 'snake', 'spider',
        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table',
        'tank', 'telephone', 'television', 'tiger', 'tractor',
        'train', 'trout', 'tulip', 'turtle', 'wardrobe',
        'whale', 'willow_tree', 'wolf', 'woman', 'worm',
    ]

    def _read_fine_label_names(cifar_dset):
        # Read the fine label names from the dataset's 'meta' file.
        # NOTE: unpickling can execute arbitrary code; this trusts the
        # archive that torchvision placed in the download directory.
        meta_fpath = ub.Path(cifar_dset.root) / cifar_dset.base_folder / 'meta'
        # The context manager closes the file handle deterministically
        with open(meta_fpath, 'rb') as file:
            meta_dict = pickle.load(file)
        return meta_dict['fine_label_names']

    DATASET = torchvision.datasets.CIFAR100
    download_dpath = (dpath / 'download').ensuredir()

    cifar_name = 'cifar100-train'
    cifar_dset = DATASET(root=download_dpath, download=True, train=True)
    classes = _read_fine_label_names(cifar_dset)
    assert classes == expected_classes
    train_coco_dset = _convert_cifar_to_kwcoco(dpath, cifar_dset, cifar_name, classes)

    cifar_name = 'cifar100-test'
    cifar_dset = DATASET(root=download_dpath, download=True, train=False)
    classes = _read_fine_label_names(cifar_dset)
    test_coco_dset = _convert_cifar_to_kwcoco(dpath, cifar_dset, cifar_name, classes)

    return [train_coco_dset, test_coco_dset]
[docs]
def main(cmdline=False):
    """
    Ensure the requested CIFAR datasets exist in kwcoco format and print
    a summary of each one.

    Args:
        cmdline (bool): if True, the configuration is allowed to read
            overrides for its options from the command line. Defaults to
            False, in which case only the built-in defaults are used.

    Returns:
        dict: maps ``'cifar10'`` and / or ``'cifar100'`` to the train and
            test datasets returned by the corresponding converter.
    """
    import scriptconfig as scfg

    class GrabCIFAR_Config(scfg.Config):
        """
        Ensure the CIFAR dataset exists in kwcoco format and prints its
        location and a bit of info.
        """
        __default__ = {
            'dpath': scfg.Path(ub.Path.appdir('kwcoco/data'), help='download location'),
            'with_10': scfg.Value(True, help='do cifar 10'),
            'with_100': scfg.Value(True, help='do cifar 100'),
        }

    # Without cmdline=True the options above can never be set from the
    # command line; the script entry point below opts in explicitly.
    config = GrabCIFAR_Config(cmdline=cmdline)
    dpath = config['dpath']

    items = {}
    if config['with_10']:
        coco_cifar10 = convert_cifar10(dpath)
        items['cifar10'] = coco_cifar10
    if config['with_100']:
        coco_cifar100 = convert_cifar100(dpath)
        items['cifar100'] = coco_cifar100

    # First pass: the dataset summaries
    for dsets in items.values():
        for dset in dsets:
            print('dset = {!r}'.format(dset))

    # Second pass: where each kwcoco file lives on disk
    for key, dsets in items.items():
        for dset in dsets:
            print('{} dset.fpath = {!r}'.format(key, dset.fpath))
    return items


if __name__ == '__main__':
    """
    CommandLine:
        python -m kwcoco.data.grab_cifar
        python -m kwcoco.data.grab_cifar --with_100=False
    """
    main(cmdline=True)