Source code for kwcoco.util.lazy_frame_backends

Ducktyped interfaces for loading subregions of images with standard slice syntax.
import os
import ubelt as ub
import numpy as np
import kwimage
from os.path import join, exists

    import xdev
[docs] profile = xdev.profile
except Exception: profile = ub.identity @ub.memoize
[docs]def _have_gdal(): try: from osgeo import gdal except ImportError: return False else: return gdal is not None
def _have_rasterio():
    """
    Check whether the ``rasterio`` package can be imported.

    Returns:
        bool: True if ``import rasterio`` succeeds.
    """
    try:
        import rasterio
    except ImportError:
        return False
    return rasterio is not None
def _have_spectral():
    """
    Check whether the ``spectral`` package can be imported.

    Returns:
        bool: True if ``import spectral`` succeeds.
    """
    try:
        import spectral
    except ImportError:
        return False
    return spectral is not None
# Maps the integer GDAL raster data type code (``band.DataType``) to the numpy
# dtype used to represent it.
# Codes 8 and 9 previously used ``np.complex_``; that alias of
# ``np.complex128`` was removed in NumPy 2.0, so the explicit name is used.
_GDAL_DTYPE_LUT = {
    1: np.uint8,
    2: np.uint16,
    3: np.int16,
    4: np.uint32,
    5: np.int32,
    6: np.float32,
    7: np.float64,
    8: np.complex128,
    9: np.complex128,
    10: np.complex64,
    11: np.complex128,
}
class LazySpectralFrameFile(ub.NiceRepr):
    """
    Lazily read windows of an image through the ``spectral`` package.

    Potentially faster than GDAL for HDR formats.

    The object is indexed like a ``(height, width, channels)`` array; the
    file is only opened the first time the dataset is needed.

    Args:
        fpath (str | PathLike): path to the image file
    """

    def __init__(self, fpath):
        self.fpath = fpath

    @ub.memoize_property
    def _ds(self):
        """
        The underlying ``spectral`` dataset, opened on first access.

        Raises:
            Exception: if ``self.fpath`` does not exist
        """
        import spectral
        from os.path import exists
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        ds = spectral.envi.open(os.fspath(self.fpath))
        return ds

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_spectral()

    @property
    def ndim(self):
        """Number of dimensions, i.e. ``len(self.shape)``."""
        return len(self.shape)

    @property
    def shape(self):
        """Shape reported by the underlying dataset."""
        return self._ds.shape

    @property
    def dtype(self):
        """Data type reported by the underlying dataset."""
        return self._ds.dtype

    def __nice__(self):
        from os.path import basename
        return '.../' + basename(self.fpath)

    def __getitem__(self, index):
        """
        Read a window of the image.

        Args:
            index (None | int | slice | Sequence): up to three entries that
                index rows, columns, and bands respectively. Missing
                trailing entries select the full extent. The band entry may
                also be a list of band indexes.

        Returns:
            the result of ``read_subregion`` for the requested rows, columns,
            and bands.
        """
        ds = self._ds
        height, width, C = ds.shape

        # Normalize the index to exactly (rows, cols, bands)
        if not ub.iterable(index):
            index = [index]
        index = list(index)
        if len(index) < 3:
            index = index + [None] * (3 - len(index))

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)

        # An explicit list of bands is used as-is; a slice is expanded
        if isinstance(channel_part, list):
            band_indices = channel_part
        else:
            band_indices = range(*channel_part.indices(C))

        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)

        img_part = ds.read_subregion(
            row_bounds=(ystart, ystop), col_bounds=(xstart, xstop),
            bands=band_indices)
        return img_part
class LazyRasterIOFrameFile(ub.NiceRepr):
    """
    Lazily read windows of an image through ``rasterio``.

    The object is indexed like a ``(height, width, channels)`` array; the
    file is only opened the first time the dataset is needed.

    Args:
        fpath (str | PathLike): path to the image file

    Ignore:
        fpath = '/home/joncrall/.cache/kwcoco/demo/large_hyperspectral/big_img_128.bsq'
        lazy_rio = LazyRasterIOFrameFile(fpath)
        ds = lazy_rio._ds

    Ignore:
        # Can rasterio read multiple bands at once?
        # Seems like that is an overhead for hyperspectral images
        import rasterio
        riods = rasterio.open(fpath)
        import timerit
        ti = timerit.Timerit(1, bestof=1, verbose=2)
        b = tuple(range(1, riods.count + 1))
        for timer in ti.reset('rasterio'):
            with timer:
                riods.read(b)

        lazy_rio = LazyRasterIOFrameFile(fpath)
        for timer in ti.reset('LazyRasterIOFrameFile'):
            with timer:
                lazy_rio[:]

        lazy_gdal = LazyGDalFrameFile(fpath)
        for timer in ti.reset('LazyGDalFrameFile'):
            with timer:
                lazy_gdal[:]
    """

    def __init__(self, fpath):
        self.fpath = fpath

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_rasterio()

    @ub.memoize_property
    def _ds(self):
        """
        The underlying ``rasterio`` dataset, opened read-only on first access.

        Raises:
            Exception: if ``self.fpath`` does not exist
        """
        import rasterio
        from os.path import exists
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        ds = rasterio.open(self.fpath, mode='r')
        return ds

    @property
    def ndim(self):
        """Number of dimensions, i.e. ``len(self.shape)``."""
        return len(self.shape)

    @property
    def shape(self):
        """The ``(height, width, band count)`` of the dataset."""
        ds = self._ds
        return (ds.height, ds.width, ds.count)

    @property
    def dtype(self):
        """The numpy dtype named by the first band's dtype string."""
        # Assume the first is the same as the rest
        ds = self._ds
        dtype = getattr(np, ds.dtypes[0])
        return dtype

    def __nice__(self):
        from os.path import basename
        return '.../' + basename(self.fpath)

    def __getitem__(self, index):
        """
        Read a window of the image.

        Args:
            index (None | int | slice | Sequence): up to three entries that
                index rows, columns, and bands respectively. Missing
                trailing entries select the full extent. The band entry may
                also be a list of band indexes.

        Returns:
            ndarray: the window, transposed so that bands are the last axis.
        """
        ds = self._ds
        width = ds.width
        height = ds.height
        C = ds.count

        # Normalize the index to exactly (rows, cols, bands)
        if not ub.iterable(index):
            index = [index]
        index = list(index)
        if len(index) < 3:
            index = index + [None] * (3 - len(index))

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)

        # An explicit list of bands is used as-is; a slice is expanded
        if isinstance(channel_part, list):
            band_indices = channel_part
        else:
            band_indices = range(*channel_part.indices(C))

        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)

        # Band indexes passed to rasterio are offset by one from ours
        indexes = [b + 1 for b in band_indices]
        img_part = ds.read(
            indexes=indexes, window=((ystart, ystop), (xstart, xstop)))
        # Move the band axis from the front to the back
        img_part = img_part.transpose(1, 2, 0)
        return img_part
def _demo_geoimg_with_nodata():
    """
    Write a small demo geotiff that contains nodata regions.

    The 'airport' test image is converted to int16, two regions are filled
    with the nodata value (-9999), and the result is written with the gdal
    backend together with an EPSG:4326 CRS and a transform fit between the
    image corners and a small box in world coordinates.

    Returns:
        ub.Path: path to the written geotiff

    Example:
        from kwcoco.util.lazy_frame_backends import *  # NOQA
        fpath = _demo_geoimg_with_nodata()
        self = LazyGDalFrameFile.demo()
    """
    import kwimage
    from osgeo import osr
    # gdal.UseExceptions()

    # Make a dummy geotiff
    imdata = kwimage.grab_test_image('airport')
    dpath = ub.Path.appdir('kwcoco/test/geotiff').ensuredir()
    geo_fpath = dpath / 'dummy_geotiff.tif'

    # compute dummy values for a geotransform to CRS84
    img_h, img_w = imdata.shape[0:2]
    img_box = kwimage.Boxes([[0, 0, img_w, img_h]], 'xywh')
    wld_box = kwimage.Boxes([[-73.7595528, 42.6552404, 0.0001, 0.0001]], 'xywh')
    img_corners = img_box.corners()
    wld_corners = wld_box.corners()
    transform = kwimage.Affine.fit(img_corners, wld_corners)

    nodata = -9999

    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    crs = srs.ExportToWkt()

    # Set a region to be nodata
    imdata = imdata.astype(np.int16)
    imdata[-100:] = nodata
    imdata[0:200:, -200:-180] = nodata

    # Write the same nodata value that was painted into the pixels
    kwimage.imwrite(geo_fpath, imdata, backend='gdal', nodata=nodata,
                    crs=crs, transform=transform)
    return geo_fpath
class LazyGDalFrameFile(ub.NiceRepr):
    """
    Lazily read windows of an image through GDAL.

    The object is indexed like a ``(height, width, channels)`` array; the
    file is only opened the first time the dataset is needed.

    TODO:
        - [ ] Move to its own backend module
        - [ ] When used with COCO, allow the image metadata to populate the
              height, width, and channels if possible.

    Args:
        fpath (str | PathLike): path to the image file

        nodata (None | str | int | float): controls nodata masking.
            If None, no masking is done. If a string, the nodata value is
            read from each band. The string ``'auto'`` selects the
            ``'float'`` masking method and only masks when a nodata pixel
            is actually found; any other string is used directly as the
            masking method (``'float'`` and ``'ma'`` are implemented).
            Otherwise the value itself is compared against the pixels.

    Attributes:
        masking_method: ``'float'`` replaces masked pixels with NaN in a
            floating point array; ``'ma'`` returns a numpy masked array.

    Example:
        >>> # xdoctest: +REQUIRES(module:osgeo)
        >>> self = LazyGDalFrameFile.demo()
        >>> print('self = {!r}'.format(self))
        >>> self[0:3, 0:3]
        >>> self[:, :, 0]
        >>> self[0]
        >>> self[0, 3]

        >>> # import kwplot
        >>> # kwplot.imshow(self[:])

    Example:
        >>> # See if we can reproduce the INTERLEAVE bug

        data = np.random.rand(128, 128, 64)
        import kwimage
        import ubelt as ub
        from os.path import join
        dpath = ub.ensure_app_cache_dir('kwcoco/tests/reader')
        fpath = join(dpath, 'foo.tiff')
        kwimage.imwrite(fpath, data, backend='skimage')
        recon1 = kwimage.imread(fpath)
        recon1.shape

        self = LazyGDalFrameFile(fpath)
        self.shape
        self[:]
    """

    def __init__(self, fpath, nodata=None):
        self.fpath = fpath
        self.nodata = nodata
        if nodata == 'auto':
            self.masking_method = 'float'
        else:
            # NOTE(review): a numeric nodata value ends up as the masking
            # method here, which __getitem__ rejects with
            # NotImplementedError. Confirm whether numeric nodata is meant
            # to default to 'float'.
            self.masking_method = nodata

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_gdal()

    @ub.memoize_property
    def _ds(self):
        """
        The underlying GDAL dataset, opened read-only on first access.

        Raises:
            Exception: if the file does not exist or GDAL cannot open it
        """
        from osgeo import gdal
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        _fpath = os.fspath(self.fpath)
        if _fpath.endswith('.hdr'):
            # Use spectral-like process to point gdal to the correct file given
            # the hdr
            ext = '.' + _read_envi_header(_fpath)['interleave']
            _fpath = ub.augpath(_fpath, ext=ext)
        ds = gdal.Open(_fpath, gdal.GA_ReadOnly)
        if ds is None:
            raise Exception((
                'GDAL Failed to open the fpath={!r} for an unknown reason. '
                'Call gdal.UseExceptions() beforehand to get the '
                'real exception').format(self.fpath))
        return ds

    @classmethod
    def demo(cls, key='astro', dsize=None):
        """
        Create an instance backed by a cached demo image.

        Args:
            key (str): name of the kwimage test image to use
            dsize (None | Tuple[int, int]): passed to
                ``kwimage.grab_test_image``

        Returns:
            LazyGDalFrameFile

        Ignore:
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
        """
        cache_dpath = ub.ensure_app_cache_dir('kwcoco/demo')
        fpath = join(cache_dpath, key + '.cog.tiff')
        depends = ub.odict(dsize=dsize)
        # Only rewrite the demo image when the requested size changes
        stamp = ub.CacheStamp(fname=key, depends=depends, dpath=cache_dpath,
                              product=[fpath])
        if stamp.expired():
            img = kwimage.grab_test_image(key, dsize=dsize)
            kwimage.imwrite(fpath, img, backend='gdal')
            stamp.renew()
        self = cls(fpath)
        return self

    @property
    def ndim(self):
        """Number of dimensions, i.e. ``len(self.shape)``."""
        return len(self.shape)

    @property
    def shape(self):
        """The ``(height, width, band count)`` of the dataset."""
        ds = self._ds
        width = ds.RasterXSize
        height = ds.RasterYSize
        C = ds.RasterCount
        return (height, width, C)

    @property
    def dtype(self):
        """The numpy dtype that corresponds to the first band's GDAL type."""
        main_band = self._ds.GetRasterBand(1)
        dtype = _GDAL_DTYPE_LUT[main_band.DataType]
        return dtype

    def __nice__(self):
        from os.path import basename
        return '.../' + basename(self.fpath)

    def __getitem__(self, index):
        """
        Read a window of the image, optionally masking nodata pixels.

        Args:
            index (None | int | slice | Sequence): up to three entries that
                index rows, columns, and bands respectively. Missing
                trailing entries select the full extent. The band entry may
                also be a list of band indexes.

        Returns:
            ndarray | numpy.ma.MaskedArray: the ``(rows, cols, bands)``
            window. With the ``'float'`` masking method masked pixels are
            NaN; with ``'ma'`` a masked array is returned.

        Raises:
            NotImplementedError: if the file is band interleaved with
                subdatasets, or the masking method is not supported
            IOError: if GDAL fails to read a band

        Ignore:
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part)

            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> self.nodata = 0
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part / 255)

        Example:
            >>> # Test nodata works correctly
            >>> # xdoctest: +REQUIRES(module:osgeo)
            >>> from kwcoco.util.lazy_frame_backends import *  # NOQA
            >>> from kwcoco.util.lazy_frame_backends import _demo_geoimg_with_nodata
            >>> fpath = _demo_geoimg_with_nodata()
            >>> self = LazyGDalFrameFile(fpath, nodata='auto')
            >>> imdata = self[:]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> import kwarray
            >>> kwplot.autompl()
            >>> imdata = kwimage.normalize_intensity(imdata)
            >>> imdata = np.nan_to_num(imdata)
            >>> kwplot.imshow(imdata)
        """
        ds = self._ds
        width = ds.RasterXSize
        height = ds.RasterYSize
        C = ds.RasterCount

        INTERLEAVE = ds.GetMetadata('IMAGE_STRUCTURE').get('INTERLEAVE', '')
        if INTERLEAVE == 'BAND':
            if len(ds.GetSubDatasets()) > 0:
                raise NotImplementedError('Cannot handle interleaved files yet')

        # Normalize the index to exactly (rows, cols, bands)
        if not ub.iterable(index):
            index = [index]
        index = list(index)
        if len(index) < 3:
            index = index + [None] * (3 - len(index))

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)

        # An explicit list of bands is used as-is; a slice is expanded
        if isinstance(channel_part, list):
            band_indices = channel_part
        else:
            band_indices = range(*channel_part.indices(C))
        num_bands = len(band_indices)

        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)

        ysize = ystop - ystart
        xsize = xstop - xstart

        gdalkw = dict(xoff=xstart, yoff=ystart,
                      win_xsize=xsize, win_ysize=ysize)

        nodata = self.nodata
        needs_nodata = nodata is not None
        auto_nodata = nodata == 'auto'
        # A string means the nodata value comes from the file itself
        read_nodata = isinstance(nodata, str)

        shape = (ysize, xsize, num_bands)
        mask_shape = (ysize, xsize,)
        if needs_nodata:
            # A pixel is masked if it is nodata in any of the read bands
            mask = np.zeros(mask_shape, dtype=bool)

        # preallocate like kwimage.im_io._imread_gdal
        from kwimage.im_io import _gdal_to_numpy_dtype
        bands = [ds.GetRasterBand(1 + band_idx)
                 for band_idx in band_indices]
        gdal_dtype = bands[0].DataType
        dtype = _gdal_to_numpy_dtype(gdal_dtype)
        try:
            img_part = np.empty(shape, dtype=dtype)
        except ValueError:
            print('ERROR')
            print('self.fpath = {!r}'.format(self.fpath))
            print('dtype = {!r}'.format(dtype))
            print('shape = {!r}'.format(shape))
            raise

        for out_idx, band in enumerate(bands):
            buf = band.ReadAsArray(**gdalkw)
            if buf is None:
                raise IOError(ub.paragraph(
                    '''
                    GDAL was unable to read band: {}, {}, with={}
                    from fpath={!r}
                    '''.format(out_idx, band, gdalkw, self.fpath)))
            if read_nodata:
                _nodata = band.GetNoDataValue()
            else:
                _nodata = nodata
            if _nodata is not None:
                mask |= (buf == _nodata)
            img_part[:, :, out_idx] = buf
            buf = None

        if auto_nodata:
            # Skip the masking work when nothing was actually nodata
            needs_nodata = mask.any()

        if needs_nodata:
            if self.masking_method == 'float':
                # Hack it so nodata becomes nan
                masked_hack_dtype = np.result_type(img_part.dtype, np.float32)
                img_part = img_part.astype(masked_hack_dtype)
                img_part[mask] = np.nan
                imdata = img_part
            elif self.masking_method == 'ma':
                # Using a regular masked array might be better.
                # The mask must have one channel per band that was read,
                # which may be fewer than the bands in the file.
                mask3 = np.repeat(mask[:, :, None], num_bands, axis=2)
                imdata = np.ma.array(img_part, mask=mask3, fill_value=None)
            else:
                raise NotImplementedError(self.masking_method)
        else:
            imdata = img_part
        return imdata

    def __array__(self, dtype=None, copy=None):
        """
        Allow this object to be passed to np.asarray

        Args:
            dtype (None | dtype): if given, the data is cast to this dtype
            copy (None | bool): accepted for compatibility with the
                arguments numpy may pass; the data is always freshly read.

        Returns:
            ndarray: the full image, i.e. ``self[:]``
        """
        imdata = self[:]
        if dtype is not None:
            imdata = imdata.astype(dtype, copy=False)
        return imdata
def _rectify_slice_dim(part, D):
    """
    Normalize one entry of an index into a form the lazy readers can use.

    Args:
        part (None | slice | int | list): the index for one dimension.
            None selects the whole dimension. A list is returned unchanged.
        D (int): the size of the dimension

    Returns:
        slice | list: a step-free slice with integer start and stop, or the
        input list.

    Raises:
        ValueError: if a slice has a step other than None or 1
        IndexError: if an integer index is out of bounds
        TypeError: if ``part`` has an unsupported type

    Example:
        >>> _rectify_slice_dim(None, 10)
        slice(0, 10, None)
        >>> _rectify_slice_dim(slice(None, -2), 10)
        slice(0, 8, None)
        >>> _rectify_slice_dim(-1, 10)
        slice(9, 10, None)
    """
    import numbers
    if part is None:
        return slice(0, D)
    if isinstance(part, slice):
        if part.step not in (None, 1):
            raise ValueError(
                'Slices with a step are not supported, got {!r}'.format(part))
        # A negative start is clamped to zero rather than wrapped, so windows
        # that hang over the top / left edge (as _validate_nonzero_data
        # builds them) are simply cropped.
        start = 0 if part.start is None else max(0, part.start)
        if part.stop is None:
            stop = D
        else:
            stop = min(D, part.stop)
            if stop < 0:
                # A negative stop counts from the end, but never before 0
                stop = max(0, D + stop)
        return slice(start, stop)
    if isinstance(part, numbers.Integral):
        # Also covers numpy integer scalars
        idx = int(part)
        if idx < 0:
            idx += D
        if not (0 <= idx < D):
            raise IndexError(
                'index {} is out of bounds for size {}'.format(part, D))
        return slice(idx, idx + 1)
    if isinstance(part, list):
        return part
    raise TypeError(part)
def _validate_nonzero_data(file):
    """
    Test to see if the image is all black.

    Sums progressively larger windows around the image center and stops at
    the first one with a positive sum. If every window sums to exactly zero
    the whole image is summed. Any exception is treated as "no data".

    May fail on all-black images

    Args:
        file: an object with a ``shape`` attribute that can be indexed with
            a tuple of slices (and with ``[:]``) to produce something with a
            ``sum`` method.

    Returns:
        bool: True if a positive sum was found, otherwise False.

    Example:
        >>> # xdoctest: +REQUIRES(module:osgeo)
        >>> import kwimage
        >>> gpath = kwimage.grab_test_image_fpath()
        >>> file = LazyGDalFrameFile(gpath)
        >>> _validate_nonzero_data(file)
    """
    try:
        import numpy as np
        # Midpoint of the first two dimensions
        mid0, mid1 = np.array(file.shape[0:2]) // 2
        midpoint = [mid0, mid1]
        # Inspect a small central window first and grow it only if needed
        for half_width in (8, 512, 2048, 5000):
            window = tuple(
                slice(mid - half_width, mid + half_width)
                for mid in midpoint)
            total = file[window].sum()
            if total > 0:
                break
        if total == 0:
            # Nothing found near the center; fall back to the full image
            total = file[:].sum()
        return total > 0
    except Exception:
        return False
def _read_envi_header(file, support_nonlowercase_params=False):
    """
    USAGE: hdr = _read_envi_header(file)

    Reads an ENVI ".hdr" file header and returns the parameters in a
    dictionary as strings.  Header field names are treated as case
    insensitive and all keys in the dictionary are lowercase.

    Modified from spectral/io/envi.py

    Args:
        file (str | PathLike): path to the header file

        support_nonlowercase_params (bool): if True, keys keep the
            capitalization used in the file instead of being lowercased.
            Defaults to False.

    Returns:
        dict: maps each field name to its value. Values are strings, except
        for brace-delimited fields other than "description", which are
        lists of strings split on commas.

    Raises:
        Exception: if the first line does not start with "ENVI" or cannot
            be decoded as text
        IndexError: if a brace-delimited value is never closed
    """
    from collections import deque
    import warnings

    with open(file, 'r') as f:
        try:
            starts_with_ENVI = f.readline().strip().startswith('ENVI')
        except UnicodeDecodeError:
            msg = (
                'File does not appear to be an ENVI header (appears to be a '
                'binary file).')
            raise Exception(msg)
        if not starts_with_ENVI:
            msg = ('File does not appear to be an ENVI header (missing "ENVI" '
                   'at beginning of first line).')
            raise Exception(msg)
        # A deque makes consuming lines from the front cheap
        lines = deque(f.readlines())

    header = {}
    have_nonlowercase_param = False
    while lines:
        line = lines.popleft()
        # Skip lines that are not "key = value" pairs, and comments
        if '=' not in line or line[0] == ';':
            continue

        key, _, val = line.partition('=')
        key = key.strip()
        if not key.islower():
            have_nonlowercase_param = True
            if not support_nonlowercase_params:
                key = key.lower()
        val = val.strip()

        if val and val[0] == '{':
            # Brace-delimited values may continue over several lines
            braced = val
            while braced[-1] != '}':
                line = lines.popleft()
                if line[0] == ';':
                    continue
                braced += '\n' + line.strip()
            if key == 'description':
                header[key] = braced.strip('{}').strip()
            else:
                header[key] = [v.strip() for v in braced[1:-1].split(',')]
        else:
            header[key] = val

    if have_nonlowercase_param and not support_nonlowercase_params:
        msg = 'Parameters with non-lowercase names encountered ' \
              'and converted to lowercase. To retain source file ' \
              'parameter name capitalization, pass ' \
              'support_nonlowercase_params=True.'
        warnings.warn(msg)
    return header