"""
Ducktyped interfaces for loading subregions of images with standard slice
syntax
"""
import os
import ubelt as ub
import numpy as np
import kwimage
from os.path import join, exists
try:
    # Use the ``profile`` decorator from xdev when it is available.
    from xdev import profile
except Exception:
    # Fallback: ``ub.identity`` returns its argument, so ``@profile`` leaves
    # the decorated function unchanged.
    profile = ub.identity
@ub.memoize
def _have_gdal():
    """
    Check whether the GDAL python bindings (``osgeo.gdal``) can be imported.

    The result is memoized, so the import is attempted at most once.

    Returns:
        bool: True if ``osgeo.gdal`` imported successfully and is not None.
    """
    try:
        from osgeo import gdal
    except ImportError:
        gdal = None
    return gdal is not None
@ub.memoize
def _have_rasterio():
    """
    Check whether the ``rasterio`` package can be imported.

    The result is memoized, so the import is attempted at most once.

    Returns:
        bool: True if ``rasterio`` imported successfully and is not None.
    """
    try:
        import rasterio
    except ImportError:
        rasterio = None
    return rasterio is not None
@ub.memoize
def _have_spectral():
    """
    Check whether the ``spectral`` package can be imported.

    The result is memoized, so the import is attempted at most once.

    Returns:
        bool: True if ``spectral`` imported successfully and is not None.
    """
    try:
        import spectral
    except ImportError:
        spectral = None
    return spectral is not None
[docs]_GDAL_DTYPE_LUT = {
1: np.uint8, 2: np.uint16,
3: np.int16, 4: np.uint32, 5: np.int32,
6: np.float32, 7: np.float64, 8: np.complex_,
9: np.complex_, 10: np.complex64, 11: np.complex128
}
class LazySpectralFrameFile(ub.NiceRepr):
    """
    Lazy, sliceable view of an ENVI image opened with the ``spectral`` package.

    Potentially faster than GDAL for HDR formats.

    Index with ``self[rows, cols, bands]``; only that subregion is read.
    """

    def __init__(self, fpath):
        self.fpath = fpath

    @ub.memoize_property
    def _ds(self):
        """
        Open (once) and return the ``spectral`` ENVI dataset.

        Raises:
            Exception: if ``self.fpath`` does not exist.
        """
        import spectral
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        return spectral.envi.open(os.fspath(self.fpath))

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_spectral()

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def shape(self):
        return self._ds.shape

    @property
    def dtype(self):
        return self._ds.dtype

    def __nice__(self):
        return '.../' + os.path.basename(self.fpath)

    @profile
    def __getitem__(self, index):
        """
        Read a ``[rows, cols, bands]`` subregion.

        Args:
            index: an int, a slice, or a sequence of up to three of those.
                The band entry may also be a list of band indices. Missing
                trailing entries select the whole axis.

        Returns:
            The array returned by the dataset's ``read_subregion``.
        """
        ds = self._ds
        height, width, num_bands = ds.shape

        parts = list(index) if ub.iterable(index) else [index]
        parts += [None] * max(0, 3 - len(parts))

        ypart = _rectify_slice_dim(parts[0], height)
        xpart = _rectify_slice_dim(parts[1], width)
        band_part = _rectify_slice_dim(parts[2], num_bands)

        # Explicit band lists pass straight through; slices expand to a range.
        if isinstance(band_part, list):
            band_indices = band_part
        else:
            band_indices = range(*band_part.indices(num_bands))

        return ds.read_subregion(
            row_bounds=(int(ypart.start), int(ypart.stop)),
            col_bounds=(int(xpart.start), int(xpart.stop)),
            bands=band_indices)
class LazyRasterIOFrameFile(ub.NiceRepr):
    """
    Lazy, sliceable view of a raster file opened with ``rasterio``.

    Index with ``self[rows, cols, bands]``; the result is in
    (rows, cols, bands) order.

    Ignore:
        fpath = '/home/joncrall/.cache/kwcoco/demo/large_hyperspectral/big_img_128.bsq'
        lazy_rio = LazyRasterIOFrameFile(fpath)
        ds = lazy_rio._ds

        # Can rasterio read multiple bands at once?
        # Seems like that is an overhead for hyperspectral images
        import rasterio
        riods = rasterio.open(fpath)
        import timerit
        ti = timerit.Timerit(1, bestof=1, verbose=2)
        b = tuple(range(1, riods.count + 1))
        for timer in ti.reset('rasterio'):
            with timer:
                riods.read(b)

        lazy_rio = LazyRasterIOFrameFile(fpath)
        for timer in ti.reset('LazyRasterIOFrameFile'):
            with timer:
                lazy_rio[:]

        lazy_gdal = LazyGDalFrameFile(fpath)
        for timer in ti.reset('LazyGDalFrameFile'):
            with timer:
                lazy_gdal[:]
    """

    def __init__(self, fpath):
        self.fpath = fpath

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_rasterio()

    @ub.memoize_property
    def _ds(self):
        """
        Open (once) and return the ``rasterio`` dataset in read mode.

        Raises:
            Exception: if ``self.fpath`` does not exist.
        """
        import rasterio
        if not exists(self.fpath):
            raise Exception('File does not exist: {}'.format(self.fpath))
        return rasterio.open(self.fpath, mode='r')

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def shape(self):
        ds = self._ds
        return (ds.height, ds.width, ds.count)

    @property
    def dtype(self):
        # Every band is assumed to share the dtype of the first band.
        first_band_dtype = self._ds.dtypes[0]
        return getattr(np, first_band_dtype)

    def __nice__(self):
        return '.../' + os.path.basename(self.fpath)

    @profile
    def __getitem__(self, index):
        """
        Read a ``[rows, cols, bands]`` subregion.

        Args:
            index: an int, a slice, or a sequence of up to three of those.
                The band entry may also be a list of 0-based band indices.
                Missing trailing entries select the whole axis.

        Returns:
            ndarray: the subregion, transposed to (rows, cols, bands).
        """
        ds = self._ds
        height, width, num_bands = ds.height, ds.width, ds.count

        parts = list(index) if ub.iterable(index) else [index]
        parts += [None] * max(0, 3 - len(parts))

        ypart = _rectify_slice_dim(parts[0], height)
        xpart = _rectify_slice_dim(parts[1], width)
        band_part = _rectify_slice_dim(parts[2], num_bands)

        if isinstance(band_part, list):
            band_indices = band_part
        else:
            band_indices = range(*band_part.indices(num_bands))

        window = ((int(ypart.start), int(ypart.stop)),
                  (int(xpart.start), int(xpart.stop)))
        # rasterio numbers bands starting at 1; our band indices start at 0.
        indexes = [b + 1 for b in band_indices]
        band_first = ds.read(indexes=indexes, window=window)
        # The read result is band-first; move bands to the last axis.
        return band_first.transpose(1, 2, 0)
def _demo_geoimg_with_nodata():
    """
    Write a small demo GeoTIFF that contains nodata regions.

    The kwimage 'airport' test image is cast to int16 and geo-registered to a
    0.0001 x 0.0001 degree box in EPSG:4326. Two regions are set to the nodata
    value -9999 before writing: the bottom 100 rows, and a 200-row by 20-column
    strip near the right edge.

    The file is rewritten on every call.

    Returns:
        ub.Path: path to the written GeoTIFF.

    Example:
        >>> # xdoctest: +REQUIRES(module:osgeo)
        >>> from kwcoco.util.lazy_frame_backends import *  # NOQA
        >>> from kwcoco.util.lazy_frame_backends import _demo_geoimg_with_nodata
        >>> fpath = _demo_geoimg_with_nodata()
        >>> self = LazyGDalFrameFile(fpath, nodata='auto')
    """
    from osgeo import osr
    # Make a dummy geotiff
    imdata = kwimage.grab_test_image('airport')
    dpath = ub.Path.appdir('kwcoco/test/geotiff').ensuredir()
    geo_fpath = dpath / 'dummy_geotiff.tif'

    # compute dummy values for a geotransform to CRS84
    img_h, img_w = imdata.shape[0:2]
    img_box = kwimage.Boxes([[0, 0, img_w, img_h]], 'xywh')
    wld_box = kwimage.Boxes([[-73.7595528, 42.6552404, 0.0001, 0.0001]], 'xywh')
    img_corners = img_box.corners()
    wld_corners = wld_box.corners()
    transform = kwimage.Affine.fit(img_corners, wld_corners)

    nodata = -9999

    srs = osr.SpatialReference()
    srs.ImportFromEPSG(4326)
    crs = srs.ExportToWkt()

    # Set a region to be nodata (int16 so -9999 is representable)
    imdata = imdata.astype(np.int16)
    imdata[-100:] = nodata
    imdata[0:200, -200:-180] = nodata
    # Use the same ``nodata`` value that was burned into the pixels.
    kwimage.imwrite(geo_fpath, imdata, backend='gdal', nodata=nodata,
                    crs=crs, transform=transform)
    return geo_fpath
class LazyGDalFrameFile(ub.NiceRepr):
    """
    Lazy, sliceable view of a raster file read through GDAL.

    Indexing with ``self[rows, cols, bands]`` reads only the requested window
    and bands from disk. The result is a (rows, cols, bands) array.

    TODO:
        - [ ] Move to its own backend module
        - [ ] When used with COCO, allow the image metadata to populate the
              height, width, and channels if possible.

    Args:
        fpath (str | PathLike):
            Path to the raster. A path ending in ``.hdr`` is redirected to the
            data file whose extension is the header's ``interleave`` value.

        nodata (None | str | int | float):
            How nodata pixels are handled when reading. A pixel is treated as
            nodata in every band if any selected band matches its nodata
            value.

            * ``None``: no handling; raw values are returned.
            * ``'auto'``: use each band's GDAL nodata value. If any pixel in
              the window matches, return a float array with NaN at those
              pixels; otherwise return the raw values.
            * ``'float'``: use each band's GDAL nodata value and always return
              a float array with NaN at nodata pixels.
            * ``'ma'``: use each band's GDAL nodata value and always return a
              ``np.ma.MaskedArray`` masking nodata pixels.
            * a number: treat that value as nodata and always return a float
              array with NaN at matching pixels.

            Any other string raises NotImplementedError when indexing.

    Attributes:
        masking_method (None | str): how masked pixels are represented,
            ``'float'`` or ``'ma'``. Derived from ``nodata``.

    Example:
        >>> # xdoctest: +REQUIRES(module:osgeo)
        >>> self = LazyGDalFrameFile.demo()
        >>> print('self = {!r}'.format(self))
        >>> self[0:3, 0:3]
        >>> self[:, :, 0]
        >>> self[0]
        >>> self[0, 3]
        >>> # import kwplot
        >>> # kwplot.imshow(self[:])

    Ignore:
        # See if we can reproduce the INTERLEAVE bug
        data = np.random.rand(128, 128, 64)
        import kwimage
        import ubelt as ub
        from os.path import join
        dpath = ub.ensure_app_cache_dir('kwcoco/tests/reader')
        fpath = join(dpath, 'foo.tiff')
        kwimage.imwrite(fpath, data, backend='skimage')
        recon1 = kwimage.imread(fpath)
        recon1.shape
        self = LazyGDalFrameFile(fpath)
        self.shape
        self[:]
    """

    def __init__(self, fpath, nodata=None):
        self.fpath = fpath
        self.nodata = nodata
        if nodata == 'auto' or (nodata is not None and
                                not isinstance(nodata, str)):
            # 'auto' and an explicit numeric nodata value both report masked
            # pixels as NaN in a float array.
            self.masking_method = 'float'
        else:
            # None, or a masking mode string such as 'float' / 'ma'.
            self.masking_method = nodata

    @classmethod
    def available(cls):
        """
        Returns True if this backend is available
        """
        return _have_gdal()

    @ub.memoize_property
    def _ds(self):
        """
        Open (once) and return the GDAL dataset in read-only mode.

        Raises:
            FileNotFoundError: if ``self.fpath`` does not exist.
            Exception: if ``gdal.Open`` returns None.
        """
        from osgeo import gdal
        if not exists(self.fpath):
            raise FileNotFoundError(
                'File does not exist: {}'.format(self.fpath))
        _fpath = os.fspath(self.fpath)
        if _fpath.endswith('.hdr'):
            # Use spectral-like process to point gdal to the correct file given
            # the hdr.
            # NOTE(review): _read_envi_header is not defined in the visible
            # part of this module; confirm it is defined or imported elsewhere.
            ext = '.' + _read_envi_header(_fpath)['interleave']
            _fpath = ub.augpath(_fpath, ext=ext)
        ds = gdal.Open(_fpath, gdal.GA_ReadOnly)
        if ds is None:
            raise Exception((
                'GDAL Failed to open the fpath={!r} for an unknown reason. '
                'Call gdal.UseExceptions() beforehand to get the '
                'real exception').format(self.fpath))
        return ds

    @classmethod
    def demo(cls, key='astro', dsize=None):
        """
        Write (and cache) a demo GeoTIFF named ``<key>.cog.tiff`` and wrap it.

        Args:
            key (str): test image name passed to ``kwimage.grab_test_image``.
            dsize (None | Tuple[int, int]): passed to
                ``kwimage.grab_test_image``. It is part of the cache key.

        Returns:
            LazyGDalFrameFile

        Ignore:
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
        """
        cache_dpath = ub.Path.appdir('kwcoco/demo').ensuredir()
        fpath = join(cache_dpath, key + '.cog.tiff')
        depends = ub.odict(dsize=dsize)
        stamp = ub.CacheStamp(fname=key, depends=depends, dpath=cache_dpath,
                              product=[fpath])
        if stamp.expired():
            img = kwimage.grab_test_image(key, dsize=dsize)
            kwimage.imwrite(fpath, img, backend='gdal')
            stamp.renew()
        return cls(fpath)

    @property
    def ndim(self):
        return len(self.shape)

    @property
    def shape(self):
        """(height, width, number of bands) of the raster."""
        ds = self._ds
        return (ds.RasterYSize, ds.RasterXSize, ds.RasterCount)

    @property
    def dtype(self):
        """numpy dtype of the first band, which is assumed for all bands."""
        main_band = self._ds.GetRasterBand(1)
        return _GDAL_DTYPE_LUT[main_band.DataType]

    def __nice__(self):
        return '.../' + os.path.basename(self.fpath)

    @staticmethod
    def _rectify_index(index, height, width, C):
        """
        Split a user index into a row slice, a col slice, and band indices.
        """
        parts = list(index) if ub.iterable(index) else [index]
        parts += [None] * max(0, 3 - len(parts))
        ypart = _rectify_slice_dim(parts[0], height)
        xpart = _rectify_slice_dim(parts[1], width)
        band_part = _rectify_slice_dim(parts[2], C)
        if isinstance(band_part, list):
            band_indices = band_part
        else:
            band_indices = range(*band_part.indices(C))
        return ypart, xpart, band_indices

    @staticmethod
    def _nodata_mask(buf, nodata_value):
        """
        Return a boolean array marking where ``buf`` equals ``nodata_value``.

        A NaN nodata value matches NaN pixels; the comparison ``buf == nan``
        would always be False.
        """
        if (isinstance(nodata_value, (float, np.floating)) and
                np.isnan(nodata_value)):
            if buf.dtype.kind in 'fc':
                return np.isnan(buf)
            # Integer data can never hold NaN.
            return np.zeros(buf.shape, dtype=bool)
        return buf == nodata_value

    def _apply_mask(self, img_part, mask):
        """
        Represent the ``(rows, cols)`` nodata ``mask`` according to
        ``self.masking_method``.
        """
        masking_method = self.masking_method
        if not isinstance(masking_method, str):
            # e.g. ``self.nodata`` was assigned after construction, leaving
            # ``masking_method`` as None.
            masking_method = 'float'
        if masking_method == 'float':
            # Hack it so nodata becomes nan
            masked_hack_dtype = np.result_type(img_part.dtype, np.float32)
            img_part = img_part.astype(masked_hack_dtype)
            # A 2D boolean mask on a 3D array blanks every band at that pixel.
            img_part[mask] = np.nan
            return img_part
        elif masking_method == 'ma':
            # The mask must match the number of *selected* bands, not the
            # total band count of the file.
            mask3 = np.repeat(mask[:, :, None], img_part.shape[2], axis=2)
            return np.ma.array(img_part, mask=mask3, fill_value=None)
        raise NotImplementedError(masking_method)

    @profile
    def __getitem__(self, index):
        """
        Read a ``[rows, cols, bands]`` subregion of the raster.

        Args:
            index: an int, a slice, or a sequence of up to three of those.
                The band entry may also be a list of 0-based band indices.
                Missing trailing entries select the whole axis. Integer
                entries keep their axis, so ``self[0]`` has one row.

        Returns:
            ndarray | np.ma.MaskedArray: a ``(rows, cols, bands)`` array.
            See the class ``nodata`` docs for when it becomes float or
            masked.

        Raises:
            NotImplementedError: for band-interleaved files that expose
                subdatasets, or for an unknown ``masking_method``.
            IOError: if GDAL fails to read a band.

        References:
            https://gis.stackexchange.com/questions/162095/gdal-driver-create-typeerror

        Ignore:
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part)
            >>> self = LazyGDalFrameFile.demo(dsize=(6600, 4400))
            >>> self.nodata = 0
            >>> index = [slice(2100, 2508, None), slice(4916, 5324, None), None]
            >>> img_part = self[index]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> kwplot.autompl()
            >>> kwplot.imshow(img_part / 255)

        Example:
            >>> # Test nodata works correctly
            >>> # xdoctest: +REQUIRES(module:osgeo)
            >>> from kwcoco.util.lazy_frame_backends import *  # NOQA
            >>> from kwcoco.util.lazy_frame_backends import _demo_geoimg_with_nodata
            >>> fpath = _demo_geoimg_with_nodata()
            >>> self = LazyGDalFrameFile(fpath, nodata='auto')
            >>> imdata = self[:]
            >>> # xdoctest: +REQUIRES(--show)
            >>> import kwplot
            >>> import kwarray
            >>> kwplot.autompl()
            >>> imdata = kwimage.normalize_intensity(imdata)
            >>> imdata = np.nan_to_num(imdata)
            >>> kwplot.imshow(imdata)
        """
        ds = self._ds
        width = ds.RasterXSize
        height = ds.RasterYSize
        C = ds.RasterCount

        # The band-by-band read below cannot handle band-interleaved files
        # that GDAL exposes as subdatasets.
        interleave = ds.GetMetadata('IMAGE_STRUCTURE').get('INTERLEAVE', '')
        if interleave == 'BAND' and len(ds.GetSubDatasets()) > 0:
            raise NotImplementedError('Cannot handle interleaved files yet')

        ypart, xpart, band_indices = self._rectify_index(
            index, height, width, C)
        ystart, ystop = int(ypart.start), int(ypart.stop)
        xstart, xstop = int(xpart.start), int(xpart.stop)
        ysize = ystop - ystart
        xsize = xstop - xstart
        gdalkw = dict(xoff=xstart, yoff=ystart,
                      win_xsize=xsize, win_ysize=ysize)

        nodata = self.nodata
        needs_nodata = nodata is not None
        # String modes ('auto', 'float', 'ma') ask GDAL for each band's own
        # nodata value; otherwise ``nodata`` itself is the value to match.
        read_nodata = isinstance(nodata, str)

        bands = [ds.GetRasterBand(1 + band_idx) for band_idx in band_indices]

        # preallocate like kwimage.im_io._imread_gdal
        from kwimage.im_io import _gdal_to_numpy_dtype
        # With an empty band selection, take the dtype from band 1 so an
        # empty (rows, cols, 0) array is returned.
        ref_band = bands[0] if bands else ds.GetRasterBand(1)
        dtype = _gdal_to_numpy_dtype(ref_band.DataType)
        shape = (ysize, xsize, len(bands))
        try:
            img_part = np.empty(shape, dtype=dtype)
        except ValueError as ex:
            raise ValueError(
                'Cannot allocate output array for fpath={!r} with '
                'dtype={!r}, shape={!r}'.format(self.fpath, dtype, shape)
            ) from ex

        mask = np.zeros((ysize, xsize), dtype=bool) if needs_nodata else None
        for out_idx, band in enumerate(bands):
            buf = band.ReadAsArray(**gdalkw)
            if buf is None:
                raise IOError(ub.paragraph(
                    '''
                    GDAL was unable to read band: {}, {}, with={}
                    from fpath={!r}
                    '''.format(out_idx, band, gdalkw, self.fpath)))
            if needs_nodata:
                band_nodata = band.GetNoDataValue() if read_nodata else nodata
                if band_nodata is not None:
                    mask |= self._nodata_mask(buf, band_nodata)
            img_part[:, :, out_idx] = buf

        if nodata == 'auto':
            # Only pay for the float conversion when nodata is actually present.
            needs_nodata = bool(mask.any())
        if not needs_nodata:
            return img_part
        return self._apply_mask(img_part, mask)

    def __array__(self, dtype=None, copy=None):
        """
        Allow this object to be passed to np.asarray

        Args:
            dtype: optional dtype to cast the result to.
            copy: accepted for NumPy 2 compatibility. The data is always
                freshly read, so a new array is returned either way.

        References:
            https://numpy.org/doc/stable/user/basics.dispatch.html
        """
        imdata = self[:]
        if dtype is not None:
            imdata = imdata.astype(dtype)
        return imdata
[docs]def _rectify_slice_dim(part, D):
if part is None:
return slice(0, D)
elif isinstance(part, slice):
start = 0 if part.start is None else max(0, part.start)
stop = D if part.stop is None else min(D, part.stop)
if stop < 0:
stop = D + stop
assert part.step is None
part = slice(start, stop)
return part
elif isinstance(part, int):
part = slice(part, part + 1)
elif isinstance(part, list):
part = part
else:
raise TypeError(part)
return part
[docs]def _validate_nonzero_data(file):
"""
Test to see if the image is all black.
May fail on all-black images
Example:
>>> # xdoctest: +REQUIRES(module:osgeo)
>>> import kwimage
>>> gpath = kwimage.grab_test_image_fpath()
>>> file = LazyGDalFrameFile(gpath)
>>> _validate_nonzero_data(file)
"""
try:
import numpy as np
# Find center point of the image
cx, cy = np.array(file.shape[0:2]) // 2
center = [cx, cy]
# Check if the center pixels have data, look at more data if needbe
sizes = [8, 512, 2048, 5000]
for d in sizes:
index = tuple(slice(c - d, c + d) for c in center)
partial_data = file[index]
total = partial_data.sum()
if total > 0:
break
if total == 0:
total = file[:].sum()
has_data = total > 0
except Exception:
has_data = False
return has_data