"""
Duck-typed interfaces for lazily loading subregions of images using standard
slice syntax.
"""
import os
import ubelt as ub
import numpy as np
import kwimage
from os.path import join, exists
from collections import OrderedDict
try:
    # Prefer ``xdev.profile`` when xdev is importable and provides it;
    # otherwise fall back to ``ub.identity``, which returns the decorated
    # function unchanged. ``profile`` must be bound in BOTH branches because
    # the ``@profile`` decorators below are evaluated at import time.
    import xdev
    profile = xdev.profile
except Exception:
    profile = ub.identity
class CacheDict(OrderedDict):
    """
    Dict with a limited length, ejecting LRUs as needed.

    Setting a key, or reading it with ``[]`` or :meth:`get`, marks it as the
    most recently used entry. When an insertion makes the length exceed
    ``cache_len``, the least recently used entries are evicted.

    Args:
        *args: positional arguments forwarded to :class:`OrderedDict`.
        cache_len (int): maximum number of entries to keep; must be > 0.
        **kwargs: keyword items forwarded to :class:`OrderedDict`.

    Raises:
        ValueError: if ``cache_len`` is not positive.

    Example:
        >>> c = CacheDict(cache_len=2)
        >>> c[1] = 1
        >>> c[2] = 2
        >>> c[3] = 3
        >>> list(c.items())
        [(2, 2), (3, 3)]
        >>> c[2]
        2
        >>> c[4] = 4
        >>> list(c.items())
        [(2, 2), (4, 4)]
        >>> d = c.copy()
        >>> d.cache_len, list(d.items())
        (2, [(2, 2), (4, 4)])

    References:
        https://gist.github.com/davesteele/44793cd0348f59f8fadd49d7799bd306
    """

    def __init__(self, *args, cache_len: int = 10, **kwargs):
        if cache_len <= 0:
            raise ValueError(
                'cache_len must be positive, got {!r}'.format(cache_len))
        # Must be set before OrderedDict.__init__, which inserts the initial
        # items through our __setitem__ (and therefore needs cache_len).
        self.cache_len = cache_len
        super().__init__(*args, **kwargs)

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        super().move_to_end(key)
        # Evict from the front (least recently used) until we fit.
        while len(self) > self.cache_len:
            oldkey = next(iter(self))
            super().__delitem__(oldkey)

    def __getitem__(self, key):
        val = super().__getitem__(key)
        # Reading an entry makes it the most recently used.
        super().move_to_end(key)
        return val

    def get(self, key, default=None):
        """
        Like :meth:`dict.get`, but a hit also refreshes the key's recency.

        The inherited ``dict.get`` does not go through ``__getitem__``, so
        without this override ``get`` would not update the LRU order.
        """
        try:
            return self[key]
        except KeyError:
            return default

    def copy(self):
        """
        Return a shallow copy with the same ``cache_len`` and item order.

        The inherited ``OrderedDict.copy`` rebuilds via the default
        constructor (resetting ``cache_len`` to 10). It also reads values
        through ``__getitem__`` while iterating, and the ``move_to_end`` there
        reorders the dict mid-copy. ``items()`` reads values without touching
        recency.
        """
        return type(self)(self.items(), cache_len=self.cache_len)
# Module-wide LRU cache (at most 32 entries) of opened GDAL datasets, keyed by
# the resolved file path. LazyGDalFrameFile._reload_cache looks datasets up
# here before calling gdal.Open, so instances pointing at the same file share
# one dataset handle.
GLOBAL_GDAL_CACHE = CacheDict(cache_len=32)
@ub.memoize
def _have_gdal():
    """
    Report whether the GDAL Python bindings can be imported.

    The import is attempted once; later calls return the memoized answer.

    Returns:
        bool: True when ``from osgeo import gdal`` yields a module.
    """
    try:
        from osgeo import gdal
    except ImportError:
        gdal = None
    return gdal is not None
@ub.memoize
def _have_rasterio():
    """
    Report whether ``rasterio`` can be imported.

    The import is attempted once; later calls return the memoized answer.

    Returns:
        bool: True when ``import rasterio`` yields a module.
    """
    try:
        import rasterio
    except ImportError:
        rasterio = None
    return rasterio is not None
@ub.memoize
def _have_spectral():
    """
    Report whether the ``spectral`` package can be imported.

    The import is attempted once; later calls return the memoized answer.

    Returns:
        bool: True when ``import spectral`` yields a module.
    """
    try:
        import spectral
    except ImportError:
        spectral = None
    return spectral is not None
[docs]_GDAL_DTYPE_LUT = {
1: np.uint8, 2: np.uint16,
3: np.int16, 4: np.uint32, 5: np.int32,
6: np.float32, 7: np.float64, 8: np.complex_,
9: np.complex_, 10: np.complex64, 11: np.complex128
}
class LazySpectralFrameFile(ub.NiceRepr):
    """
    Lazy, sliceable reader backed by the ``spectral`` package's ENVI support.

    Potentially faster than GDAL for HDR formats.

    The file is opened on first access. Indexing with
    ``self[rows, cols, bands]`` reads only the requested subregion through
    ``read_subregion``.

    Args:
        fpath (str | PathLike): path handed to ``spectral.envi.open``
            (presumably the ``.hdr`` header file -- confirm with callers).
    """

    def __init__(self, fpath):
        self.fpath = fpath

    @ub.memoize_property
    def _ds(self):
        """
        The opened ``spectral`` image object (opened once, then memoized).

        Raises:
            Exception: if ``self.fpath`` does not exist.
        """
        import spectral
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        ds = spectral.envi.open(os.fspath(self.fpath))
        return ds

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_spectral()

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def shape(self):
        """tuple: the ``(height, width, num_bands)`` of the underlying image."""
        return self._ds.shape

    @property
    def dtype(self):
        return self._ds.dtype

    def __nice__(self):
        from os.path import basename
        return '.../' + basename(self.fpath)

    @profile
    def __getitem__(self, index):
        """
        Read a subregion of the image.

        Args:
            index: up to three per-axis indices ``(rows, cols, bands)``. Each
                may be None, a unit-step slice, or an int; the band axis may
                also be a list of band indices. Missing axes select
                everything.

        Returns:
            The array returned by ``read_subregion``.
        """
        ds = self._ds
        height, width, C = ds.shape

        if not ub.iterable(index):
            index = [index]
        index = list(index)
        if len(index) < 3:
            # Unspecified trailing axes select their full extent.
            index = index + [None] * (3 - len(index))

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)
        if isinstance(channel_part, list):
            band_indices = channel_part
        else:
            band_indices = range(*channel_part.indices(C))

        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)
        img_part = ds.read_subregion(
            row_bounds=(ystart, ystop), col_bounds=(xstart, xstop),
            bands=band_indices)
        return img_part
class LazyRasterIOFrameFile(ub.NiceRepr):
    """
    Lazy, sliceable reader backed by ``rasterio``.

    The dataset is opened read-only on first access. Indexing with
    ``self[rows, cols, bands]`` reads only the requested window and returns
    it with the band axis last.

    Args:
        fpath (str | PathLike): path to a raster that rasterio can open.

    Ignore:
        fpath = '/home/joncrall/.cache/kwcoco/demo/large_hyperspectral/big_img_128.bsq'
        lazy_rio = LazyRasterIOFrameFile(fpath)
        ds = lazy_rio._ds

        # Can rasterio read multiple bands at once?
        # Seems like that is an overhead for hyperspectral images
        import rasterio
        riods = rasterio.open(fpath)
        import timerit
        ti = timerit.Timerit(1, bestof=1, verbose=2)
        b = tuple(range(1, riods.count + 1))
        for timer in ti.reset('rasterio'):
            with timer:
                riods.read(b)
        lazy_rio = LazyRasterIOFrameFile(fpath)
        for timer in ti.reset('LazyRasterIOFrameFile'):
            with timer:
                lazy_rio[:]
        lazy_gdal = LazyGDalFrameFile(fpath)
        for timer in ti.reset('LazyGDalFrameFile'):
            with timer:
                lazy_gdal[:]
    """

    def __init__(self, fpath):
        self.fpath = fpath

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_rasterio()

    @ub.memoize_property
    def _ds(self):
        """
        The rasterio dataset for ``self.fpath`` (opened once with mode 'r').

        Raises:
            Exception: if ``self.fpath`` does not exist.
        """
        import rasterio
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        ds = rasterio.open(self.fpath, mode='r')
        return ds

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def shape(self):
        """tuple: ``(height, width, num_bands)``"""
        ds = self._ds
        return (ds.height, ds.width, ds.count)

    @property
    def dtype(self):
        # Assume the first is the same as the rest
        ds = self._ds
        dtype = getattr(np, ds.dtypes[0])
        return dtype

    def __nice__(self):
        from os.path import basename
        return '.../' + basename(self.fpath)

    @profile
    def __getitem__(self, index):
        """
        Read a window of the raster.

        Args:
            index: up to three per-axis indices ``(rows, cols, bands)``. Each
                may be None, a unit-step slice, or an int; the band axis may
                also be a list of 0-based band indices. Missing axes select
                everything.

        Returns:
            ndarray: the window, transposed so bands are the last axis.
        """
        ds = self._ds
        width = ds.width
        height = ds.height
        C = ds.count

        if not ub.iterable(index):
            index = [index]
        index = list(index)
        if len(index) < 3:
            # Unspecified trailing axes select their full extent.
            index = index + [None] * (3 - len(index))

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)
        if isinstance(channel_part, list):
            band_indices = channel_part
        else:
            band_indices = range(*channel_part.indices(C))

        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)
        # rasterio band indexes are 1-based
        indexes = [b + 1 for b in band_indices]
        img_part = ds.read(indexes=indexes, window=((ystart, ystop), (xstart, xstop)))
        # Move the band axis (first in rasterio's output) to the end
        img_part = img_part.transpose(1, 2, 0)
        return img_part
def _demo_geoimg_with_nodata():
    """
    Write a small demo GeoTIFF containing nodata pixels and return its path.

    The kwimage "airport" test image is cast to int16 and georeferenced in
    EPSG:4326 with a dummy affine transform. The nodata value -9999 is
    written into the bottom 100 rows and into columns ``-200:-180`` of rows
    ``0:200``. The same value is recorded as the file's nodata. Each call
    rewrites the file.

    Returns:
        PathLike: path to ``dummy_geotiff.tif`` in the kwcoco test app dir.

    Example:
        >>> # xdoctest: +REQUIRES(module:osgeo)
        >>> from kwcoco.util.lazy_frame_backends import *  # NOQA
        >>> from kwcoco.util.lazy_frame_backends import _demo_geoimg_with_nodata
        >>> fpath = _demo_geoimg_with_nodata()
        >>> self = LazyGDalFrameFile(fpath, nodata='auto')
    """
    from osgeo import osr
    # Make a dummy geotiff
    imdata = kwimage.grab_test_image('airport')
    dpath = ub.Path.appdir('kwcoco/test/geotiff').ensuredir()
    geo_fpath = dpath / 'dummy_geotiff.tif'

    # Compute dummy values for a geotransform to CRS84: map the image box
    # onto a tiny 0.0001 x 0.0001 box of world coordinates.
    img_h, img_w = imdata.shape[0:2]
    img_box = kwimage.Boxes([[0, 0, img_w, img_h]], 'xywh')
    wld_box = kwimage.Boxes([[-73.7595528, 42.6552404, 0.0001, 0.0001]], 'xywh')
    transform = kwimage.Affine.fit(img_box.corners(), wld_box.corners())

    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    crs = srs.ExportToWkt()

    # Cast to a signed type so the negative nodata value is representable,
    # then set some regions to be nodata.
    nodata = -9999
    imdata = imdata.astype(np.int16)
    imdata[-100:] = nodata
    imdata[0:200, -200:-180] = nodata

    # Record the same nodata value used above so pixels and metadata agree.
    kwimage.imwrite(geo_fpath, imdata, backend='gdal', nodata=nodata,
                    crs=crs, transform=transform)
    return geo_fpath
class LazyGDalFrameFile(ub.NiceRepr):
    """
    Lazy, sliceable reader for rasters opened with GDAL.

    The GDAL dataset is opened on first use. It is shared with other
    instances through ``GLOBAL_GDAL_CACHE``, keyed by resolved path. Indexing
    with ``self[rows, cols, bands]`` reads only the requested window, one band
    at a time.

    TODO:
        - [ ] Move to its own backend module
        - [ ] When used with COCO, allow the image metadata to populate the
              height, width, and channels if possible.

    Args:
        fpath (str | PathLike):
            Path to the raster. A path ending in ``.hdr`` is redirected to
            the ENVI data file whose extension is the header's
            ``interleave`` value.

        nodata (None | str | Number):
            Controls nodata handling when indexing.

            - ``None`` disables masking.
            - Any string means "use each band's stored nodata value":
              ``'auto'`` converts to a NaN-filled float array only if a
              nodata pixel is found; ``'float'`` always returns a NaN-filled
              float array; ``'ma'`` returns a :class:`numpy.ma.MaskedArray`.
              Any other string raises ``NotImplementedError`` on read.
            - A number marks pixels equal to it as nodata and returns a
              NaN-filled float array.

    Attributes:
        masking_method (None | str | Number):
            How masked pixels are represented. It is set from ``nodata``,
            with ``'auto'`` becoming ``'float'``. Non-string values are
            treated as ``'float'`` when reading.

    Example:
        >>> # xdoctest: +REQUIRES(module:osgeo)
        >>> self = LazyGDalFrameFile.demo()
        >>> print('self = {!r}'.format(self))
        >>> self[0:3, 0:3]
        >>> self[:, :, 0]
        >>> self[0]
        >>> self[0, 3]
        >>> # import kwplot
        >>> # kwplot.imshow(self[:])

    Ignore:
        # See if we can reproduce the INTERLEAVE bug
        data = np.random.rand(128, 128, 64)
        import kwimage
        import ubelt as ub
        from os.path import join
        dpath = ub.ensure_app_cache_dir('kwcoco/tests/reader')
        fpath = join(dpath, 'foo.tiff')
        kwimage.imwrite(fpath, data, backend='skimage')
        recon1 = kwimage.imread(fpath)
        recon1.shape
        self = LazyGDalFrameFile(fpath)
        self.shape
        self[:]
    """

    def __init__(self, fpath, nodata=None):
        self.fpath = fpath
        self.nodata = nodata
        if nodata == 'auto':
            self.masking_method = 'float'
        else:
            self.masking_method = nodata
        self._ds_cache = None

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_gdal()

    @profile
    def _reload_cache(self):
        """
        Open the GDAL dataset for ``self.fpath`` and store it in
        ``self._ds_cache``. A dataset already present in ``GLOBAL_GDAL_CACHE``
        is reused.

        Raises:
            Exception: if the file (or the data file a ``.hdr`` points to)
                does not exist, or GDAL otherwise fails to open it.
        """
        from osgeo import gdal
        _fpath = os.fspath(self.fpath)
        if _fpath.endswith('.hdr'):
            # Use spectral-like process to point gdal to the correct file given
            # the hdr
            ext = '.' + _read_envi_header(_fpath)['interleave']
            _fpath = ub.augpath(_fpath, ext=ext)
        if _fpath in GLOBAL_GDAL_CACHE:
            self._ds_cache = GLOBAL_GDAL_CACHE[_fpath]
        else:
            ds = gdal.Open(_fpath, gdal.GA_ReadOnly)
            if ds is None:
                if not exists(self.fpath):
                    raise Exception('File does not exist: {}'.format(self.fpath))
                if not exists(_fpath):
                    # Only reachable via the .hdr redirect: the header exists
                    # but the data file it names does not.
                    raise Exception('File does not exist: {}'.format(_fpath))
                raise Exception((
                    'GDAL Failed to open the fpath={!r} for an unknown reason. '
                    'Call gdal.UseExceptions() beforehand to get the '
                    'real exception').format(self.fpath))
            self._ds_cache = ds
            GLOBAL_GDAL_CACHE[_fpath] = ds

    @property
    def _ds(self):
        """The opened GDAL dataset (opened lazily on first access)."""
        if self._ds_cache is None:
            self._reload_cache()
        ds = self._ds_cache
        return ds

    @classmethod
    def demo(cls, key='astro', dsize=None):
        """
        Write (if the cache stamp is expired) a kwimage test image as a
        GeoTIFF in the kwcoco demo cache and return a reader for it.

        Args:
            key (str): name of the kwimage test image.
            dsize (None | Tuple[int, int]): optional size passed to
                ``kwimage.grab_test_image``; part of the cache dependency.

        Returns:
            LazyGDalFrameFile

        Ignore:
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
        """
        cache_dpath = ub.ensure_app_cache_dir('kwcoco/demo')
        fpath = join(cache_dpath, key + '.cog.tiff')
        depends = ub.odict(dsize=dsize)
        stamp = ub.CacheStamp(fname=key, depends=depends, dpath=cache_dpath,
                              product=[fpath])
        if stamp.expired():
            img = kwimage.grab_test_image(key, dsize=dsize)
            kwimage.imwrite(fpath, img, backend='gdal')
            stamp.renew()
        self = cls(fpath)
        return self

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def shape(self):
        """
        tuple: ``(height, width, num_bands)`` of the top-level dataset.

        NOTE: files with INTERLEAVE=BAND subdatasets are reported using the
        top-level dataset's sizes; ``__getitem__`` refuses to read them.
        """
        ds = self._ds
        width = ds.RasterXSize
        height = ds.RasterYSize
        C = ds.RasterCount
        return (height, width, C)

    @property
    def dtype(self):
        main_band = self._ds.GetRasterBand(1)
        dtype = _GDAL_DTYPE_LUT[main_band.DataType]
        return dtype

    def __nice__(self):
        from os.path import basename
        return '.../' + basename(self.fpath)

    @profile
    def __getitem__(self, index):
        """
        Read a window of the raster, applying nodata masking if configured.

        Args:
            index: up to three per-axis indices ``(rows, cols, bands)``. Each
                may be None, a unit-step slice, or an int; the band axis may
                also be a list of 0-based band indices. Missing axes select
                everything.

        Returns:
            ndarray | numpy.ma.MaskedArray: the ``(rows, cols, bands)``
                window; see the ``nodata`` argument of the class for how
                nodata pixels are represented.

        Raises:
            NotImplementedError: for INTERLEAVE=BAND files with subdatasets,
                or an unrecognized string masking method.
            IOError: if GDAL fails to read a band.

        References:
            https://gis.stackexchange.com/questions/162095/gdal-driver-create-typeerror

        Ignore:
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part)
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> self.nodata = 0
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part / 255)

        Example:
            >>> # Test nodata works correctly
            >>> # xdoctest: +REQUIRES(module:osgeo)
            >>> from kwcoco.util.lazy_frame_backends import *  # NOQA
            >>> from kwcoco.util.lazy_frame_backends import _demo_geoimg_with_nodata
            >>> fpath = _demo_geoimg_with_nodata()
            >>> self = LazyGDalFrameFile(fpath, nodata='auto')
            >>> imdata = self[:]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> import kwarray
            >>> kwplot.autompl()
            >>> imdata = kwimage.normalize_intensity(imdata)
            >>> imdata = np.nan_to_num(imdata)
            >>> kwplot.imshow(imdata)
        """
        ds = self._ds
        width = ds.RasterXSize
        height = ds.RasterYSize
        C = ds.RasterCount

        INTERLEAVE = ds.GetMetadata('IMAGE_STRUCTURE').get('INTERLEAVE', '')
        if INTERLEAVE == 'BAND':
            if len(ds.GetSubDatasets()) > 0:
                raise NotImplementedError('Cannot handle interleaved files yet')

        if not ub.iterable(index):
            index = [index]
        index = list(index)
        if len(index) < 3:
            # Unspecified trailing axes select their full extent.
            index = index + [None] * (3 - len(index))

        ypart = _rectify_slice_dim(index[0], height)
        xpart = _rectify_slice_dim(index[1], width)
        channel_part = _rectify_slice_dim(index[2], C)
        if isinstance(channel_part, list):
            band_indices = channel_part
        else:
            band_indices = range(*channel_part.indices(C))

        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)
        ysize = ystop - ystart
        xsize = xstop - xstart
        gdalkw = dict(xoff=xstart, yoff=ystart,
                      win_xsize=xsize, win_ysize=ysize)

        nodata = self.nodata
        needs_nodata = nodata is not None
        auto_nodata = nodata == 'auto'
        # Any string nodata means "use the nodata value stored in each band".
        read_nodata = isinstance(nodata, str)
        shape = (ysize, xsize, len(band_indices))
        if needs_nodata:
            # One 2D mask shared by all bands: a pixel is flagged if it is
            # nodata in any of the selected bands.
            mask = np.zeros((ysize, xsize), dtype=bool)

        # preallocate like kwimage.im_io._imread_gdal
        from kwimage.im_io import _gdal_to_numpy_dtype
        bands = [ds.GetRasterBand(1 + band_idx) for band_idx in band_indices]
        gdal_dtype = bands[0].DataType
        dtype = _gdal_to_numpy_dtype(gdal_dtype)
        try:
            img_part = np.empty(shape, dtype=dtype)
        except ValueError as ex:
            raise ValueError(
                'Failed to allocate the output buffer for fpath={!r} with '
                'dtype={!r} and shape={!r}'.format(self.fpath, dtype, shape)
            ) from ex

        for out_idx, band in enumerate(bands):
            buf = band.ReadAsArray(**gdalkw)
            if buf is None:
                raise IOError(ub.paragraph(
                    '''
                    GDAL was unable to read band: {}, {}, with={}
                    from fpath={!r}
                    '''.format(out_idx, band, gdalkw, self.fpath)))
            if read_nodata:
                _nodata = band.GetNoDataValue()
            else:
                _nodata = nodata
            if _nodata is not None:
                # NaN is the only value unequal to itself, and ``buf == nan``
                # is never True, so NaN nodata needs an explicit isnan test.
                if _nodata != _nodata:
                    mask |= np.isnan(buf)
                else:
                    mask |= (buf == _nodata)
            img_part[:, :, out_idx] = buf
            # Drop the reference so this band's buffer can be freed before
            # the next band is read.
            buf = None

        if auto_nodata:
            needs_nodata = mask.any()
        if needs_nodata:
            masking_method = self.masking_method
            if not isinstance(masking_method, str):
                # A numeric nodata leaves masking_method equal to that number
                # (see __init__); represent such masks with NaN.
                masking_method = 'float'
            if masking_method == 'float':
                # Hack it so nodata becomes nan
                masked_hack_dtype = np.result_type(img_part.dtype, np.float32)
                img_part = img_part.astype(masked_hack_dtype)
                img_part[mask] = np.nan
                imdata = img_part
            elif masking_method == 'ma':
                # Replicate the 2D mask once per band actually read, which may
                # be fewer than the dataset's band count C.
                mask3 = np.dstack([mask] * img_part.shape[2])
                imdata = np.ma.array(img_part, mask=mask3, fill_value=None)
            else:
                raise NotImplementedError(masking_method)
        else:
            imdata = img_part
        return imdata

    def __array__(self, dtype=None, copy=None):
        """
        Allow this object to be passed to np.asarray

        Args:
            dtype: optional dtype to cast the loaded data to.
            copy: accepted for compatibility with the NumPy 2 ``__array__``
                protocol. Every call reads a fresh array, so it is unused.

        References:
            https://numpy.org/doc/stable/user/basics.dispatch.html
        """
        data = self[:]
        if dtype is not None:
            data = data.astype(dtype, copy=False)
        return data
[docs]def _rectify_slice_dim(part, D):
if part is None:
return slice(0, D)
elif isinstance(part, slice):
start = 0 if part.start is None else max(0, part.start)
stop = D if part.stop is None else min(D, part.stop)
if stop < 0:
stop = D + stop
assert part.step is None
part = slice(start, stop)
return part
elif isinstance(part, int):
part = slice(part, part + 1)
elif isinstance(part, list):
part = part
else:
raise TypeError(part)
return part
[docs]def _validate_nonzero_data(file):
"""
Test to see if the image is all black.
May fail on all-black images
Example:
>>> # xdoctest: +REQUIRES(module:osgeo)
>>> import kwimage
>>> gpath = kwimage.grab_test_image_fpath()
>>> file = LazyGDalFrameFile(gpath)
>>> _validate_nonzero_data(file)
"""
try:
import numpy as np
# Find center point of the image
cx, cy = np.array(file.shape[0:2]) // 2
center = [cx, cy]
# Check if the center pixels have data, look at more data if needbe
sizes = [8, 512, 2048, 5000]
for d in sizes:
index = tuple(slice(c - d, c + d) for c in center)
partial_data = file[index]
total = partial_data.sum()
if total > 0:
break
if total == 0:
total = file[:].sum()
has_data = total > 0
except Exception:
has_data = False
return has_data