"""
The ChannelSpec has these simple rules:
.. code::
* each 1D channel is a alphanumeric string.
* The pipe ('|') separates aligned early fused stremas (non-communative)
* The comma (',') separates late-fused streams, (happens after pipe operations, and is communative)
* Certain common sets of early fused channels have codenames, for example:
rgb = r|g|b
rgba = r|g|b|a
dxdy = dy|dy
For single arrays, the spec is always an early fused spec.
TODO:
- [X] : normalize representations? e.g: rgb = r|g|b? - OPTIONAL
- [X] : rename to BandsSpec or SensorSpec? - REJECTED
- [ ] : allow bands to be coerced, i.e. rgb -> gray, or gray->rgb
"""
import ubelt as ub
import six
import functools
[docs]class FusedChannelSpec(ub.NiceRepr):
"""
A specific type of channel spec with only one early fused stream.
The channels in this stream are non-communative
Note:
This class name and API is in flux and subject to change.
TODO:
A special code indicating a name and some number of bands that that
names contains, this would primarilly be used for large numbers of
channels produced by a network. Like:
resnet_d35d060_L5:512
or
resnet_d35d060_L5[:512]
might refer to a very specific (hashed) set of resnet parameters
with 512 bands
maybe we can do something slicly like:
resnet_d35d060_L5[A:B]
resnet_d35d060_L5:A:B
Do we want to "just store the code" and allow for parsing later?
Or do we want to ensure the serialization is parsed before we
construct the data structure?
"""
[docs] _alias_lut = {
'rgb': 'r|g|b',
'rgba': 'r|g|b|a',
'dxdy': 'dx|dy',
'fxfy': 'fx|fy',
}
[docs] _size_lut = {k: v.count('|') + 1 for k, v in _alias_lut.items()}
def __init__(self, parsed):
self.parsed = parsed
[docs] def __len__(self):
return len(self.parsed)
[docs] def __getitem__(self, index):
return self.__class__(self.parsed[index])
@classmethod
[docs] def concat(cls, items):
combined = list(ub.flatten(item.parsed for item in items))
self = cls(combined)
return self
@ub.memoize_property
[docs] def spec(self):
return '|'.join(self.parsed)
@ub.memoize
[docs] def unique(self):
return set(self.parsed)
@classmethod
[docs] def parse(cls, spec):
self = cls(spec.split('|'))
return self
@classmethod
[docs] def coerce(cls, data):
"""
Example:
>>> FusedChannelSpec.coerce(['a', 'b', 'c'])
>>> FusedChannelSpec.coerce('a|b|c')
>>> FusedChannelSpec.coerce(3)
>>> FusedChannelSpec.coerce(FusedChannelSpec(['a']))
"""
if isinstance(data, list):
self = cls(data)
elif isinstance(data, str):
self = cls.parse(data)
elif isinstance(data, int):
# we know the number of channels, but not their names
self = cls(['u{}'.format(i) for i in range(data)])
elif isinstance(data, cls):
self = data
elif isinstance(data, ChannelSpec):
parsed = data.parse()
if len(parsed) == 1:
self = cls(ub.peek(parsed.values()))
else:
raise ValueError(
'Cannot coerce ChannelSpec to a FusedChannelSpec '
'when there are multiple streams')
else:
raise TypeError('unknown type {}'.format(type(data)))
return self
[docs] def __nice__(self):
return self.spec
[docs] def __json__(self):
return self.spec
@ub.memoize_method
[docs] def normalize(self):
"""
Replace aliases with explicit single-band-per-code specs
Example:
>>> self = FusedChannelSpec.coerce('b1|b2|b3|rgb')
>>> normed = self.normalize()
>>> print('normed = {}'.format(ub.repr2(normed, nl=1)))
"""
norm_parsed = list(ub.flatten(
self._alias_lut.get(v, v).split('|')
for v in self.parsed))
normed = FusedChannelSpec(norm_parsed)
return normed
[docs] def __contains__(self, key):
"""
Example:
>>> FCS = FusedChannelSpec.coerce
>>> 'disparity' in FCS('rgb|disparity|flowx|flowy')
True
>>> 'gray' in FCS('rgb|disparity|flowx|flowy')
False
"""
return key in self.unique()
# def can_coerce(self, other):
# # return if we can coerce this band repr to another, like
# # gray to rgb or rgb to gray
[docs] def code_list(self):
"""
Return the expanded code list
"""
return self.parsed
# @ub.memoize_property
# def code_oset(self):
# return ub.oset(self.normalize().parsed)
@ub.memoize_method
[docs] def as_list(self):
return self.normalize().parsed
@ub.memoize_method
[docs] def as_oset(self):
return ub.oset(self.normalize().parsed)
@ub.memoize_method
[docs] def as_set(self):
return set(self.normalize().parsed)
[docs] def difference(self, other):
"""
Set difference
Example:
>>> FCS = FusedChannelSpec.coerce
>>> self = FCS('rgb|disparity|flowx|flowy')
>>> other = FCS('r|b')
>>> self.difference(other)
>>> other = FCS('flowx')
>>> self.difference(other)
"""
other_norm = ub.oset(other.normalize().parsed)
self_norm = ub.oset(self.normalize().parsed)
new_parsed = list(self_norm - other_norm)
new = self.__class__(new_parsed)
return new
[docs] def intersection(self, other):
"""
Example:
>>> FCS = FusedChannelSpec.coerce
>>> self = FCS('rgb|disparity|flowx|flowy')
>>> other = FCS('r|b|XX')
>>> self.intersection(other)
"""
other_norm = ub.oset(other.normalize().parsed)
self_norm = ub.oset(self.normalize().parsed)
new_parsed = list(self_norm & other_norm)
new = self.__class__(new_parsed)
return new
[docs] def component_indices(self, axis=2):
"""
Look up component indices within this stream
Example:
>>> FCS = FusedChannelSpec.coerce
>>> self = FCS('disparity|rgb|flowx|flowy')
>>> component_indices = self.component_indices()
>>> print('component_indices = {}'.format(ub.repr2(component_indices, nl=1)))
component_indices = {
'disparity': (slice(...), slice(...), slice(0, 1, None)),
'flowx': (slice(...), slice(...), slice(4, 5, None)),
'flowy': (slice(...), slice(...), slice(5, 6, None)),
'rgb': (slice(...), slice(...), slice(1, 4, None)),
}
"""
component_indices = dict()
idx1 = 0
for part in self.parsed:
size = self._size_lut.get(part, 1)
idx2 = idx1 + size
index = tuple([slice(None)] * axis + [slice(idx1, idx2)])
idx1 = idx2
component_indices[part] = index
return component_indices
[docs]class ChannelSpec(ub.NiceRepr):
"""
Parse and extract information about network input channel specs for
early or late fusion networks.
Note:
This class name and API is in flux and subject to change.
Note:
The pipe ('|') character represents an early-fused input stream, and
order matters (it is non-communative).
The comma (',') character separates different inputs streams/branches
for a multi-stream/branch network which will be lated fused. Order does
not matter
Example:
>>> # Integer spec
>>> ChannelSpec.coerce(3)
<ChannelSpec(u0|u1|u2) ...>
>>> # single mode spec
>>> ChannelSpec.coerce('rgb')
<ChannelSpec(rgb) ...>
>>> # early fused input spec
>>> ChannelSpec.coerce('rgb|disprity')
<ChannelSpec(rgb|disprity) ...>
>>> # late fused input spec
>>> ChannelSpec.coerce('rgb,disprity')
<ChannelSpec(rgb,disprity) ...>
>>> # early and late fused input spec
>>> ChannelSpec.coerce('rgb|ir,disprity')
<ChannelSpec(rgb|ir,disprity) ...>
Example:
>>> self = ChannelSpec('gray')
>>> print('self.info = {}'.format(ub.repr2(self.info, nl=1)))
>>> self = ChannelSpec('rgb')
>>> print('self.info = {}'.format(ub.repr2(self.info, nl=1)))
>>> self = ChannelSpec('rgb|disparity')
>>> print('self.info = {}'.format(ub.repr2(self.info, nl=1)))
>>> self = ChannelSpec('rgb|disparity,disparity')
>>> print('self.info = {}'.format(ub.repr2(self.info, nl=1)))
>>> self = ChannelSpec('rgb,disparity,flowx|flowy')
>>> print('self.info = {}'.format(ub.repr2(self.info, nl=1)))
Example:
>>> specs = [
>>> 'rgb', # and rgb input
>>> 'rgb|disprity', # rgb early fused with disparity
>>> 'rgb,disprity', # rgb early late with disparity
>>> 'rgb|ir,disprity', # rgb early fused with ir and late fused with disparity
>>> 3, # 3 unknown channels
>>> ]
>>> for spec in specs:
>>> print('=======================')
>>> print('spec = {!r}'.format(spec))
>>> #
>>> self = ChannelSpec.coerce(spec)
>>> print('self = {!r}'.format(self))
>>> sizes = self.sizes()
>>> print('sizes = {!r}'.format(sizes))
>>> print('self.info = {}'.format(ub.repr2(self.info, nl=1)))
>>> #
>>> item = self._demo_item((1, 1), rng=0)
>>> inputs = self.encode(item)
>>> components = self.decode(inputs)
>>> input_shapes = ub.map_vals(lambda x: x.shape, inputs)
>>> component_shapes = ub.map_vals(lambda x: x.shape, components)
>>> print('item = {}'.format(ub.repr2(item, precision=1)))
>>> print('inputs = {}'.format(ub.repr2(inputs, precision=1)))
>>> print('input_shapes = {}'.format(ub.repr2(input_shapes)))
>>> print('components = {}'.format(ub.repr2(components, precision=1)))
>>> print('component_shapes = {}'.format(ub.repr2(component_shapes, nl=1)))
"""
[docs] _alias_lut = {
'rgb': 'r|g|b',
'rgba': 'r|g|b|a',
'dxdy': 'dx|dy',
'fxfy': 'fx|fy',
}
[docs] _size_lut = {k: v.count('|') + 1 for k, v in _alias_lut.items()}
def __init__(self, spec, parsed=None):
# TODO: allow integer specs
self.spec = spec
self._info = {
'spec': spec,
'parsed': parsed,
}
[docs] def __nice__(self):
return self.spec
[docs] def __json__(self):
return self.spec
[docs] def __contains__(self, key):
"""
Example:
>>> 'disparity' in ChannelSpec('rgb,disparity,flowx|flowy')
True
>>> 'gray' in ChannelSpec('rgb,disparity,flowx|flowy')
False
"""
return key in self.unique()
@property
[docs] def info(self):
return ub.dict_union(self._info, {
'unique': self.unique(),
'normed': self.normalize(),
})
@classmethod
[docs] def coerce(cls, data):
if isinstance(data, cls):
self = data
return self
else:
if isinstance(data, int):
# we know the number of channels, but not their names
spec = '|'.join(['u{}'.format(i) for i in range(data)])
elif isinstance(data, six.string_types):
spec = data
else:
raise TypeError(type(data))
self = cls(spec)
return self
[docs] def parse(self):
"""
Build internal representation
"""
if self._info.get('parsed', None) is None:
# commas break inputs into multiple streams
stream_specs = self.spec.split(',')
parsed = {ss: ss.split('|') for ss in stream_specs}
self._info['parsed'] = parsed
return self._info['parsed']
[docs] def normalize(self):
"""
Replace aliases with explicit single-band-per-code specs
Example:
>>> self = ChannelSpec('b1|b2|b3|rgb')
>>> self.normalize()
>>> list(self.keys())
"""
new_parsed = {}
for k1, v1 in self.parse().items():
norm_vals = list(
ub.flatten(self._alias_lut.get(v, v).split('|') for v in v1))
norm_key = '|'.join(norm_vals)
new_parsed[norm_key] = norm_vals
new_spec = ','.join(list(new_parsed.keys()))
normed = ChannelSpec(new_spec, parsed=new_parsed)
return normed
# spec = self.spec
# stream_specs = spec.split(',')
# parsed = {ss: ss for ss in stream_specs}
# for k1 in parsed.keys():
# for alias, alias_spec in self._alias_lut.items():
# parsed[k1] = parsed[k1].replace(alias, alias_spec)
# parsed = {k: v.split('|') for k, v in parsed.items()}
# return parsed
[docs] def keys(self):
spec = self.spec
stream_specs = spec.split(',')
for spec in stream_specs:
yield spec
[docs] def values(self):
return self.parse().values()
[docs] def items(self):
return self.parse().items()
[docs] def streams(self):
"""
Breaks this spec up into one spec for each early-fused input stream
"""
streams = [self.__class__(spec) for spec in self.keys()]
return streams
[docs] def code_list(self):
parsed = self.parse()
if len(parsed) > 1:
raise Exception(
'Can only work on single-streams. '
'TODO make class for single streams')
return ub.peek(parsed.values())
[docs] def difference(self, other):
"""
Set difference
Example:
>>> self = ChannelSpec('rgb|disparity,flowx|flowy')
>>> other = ChannelSpec('rgb')
>>> self.difference(other)
>>> other = ChannelSpec('flowx')
>>> self.difference(other)
"""
assert len(list(other.keys())) == 1, 'can take diff with one stream'
other_norm = ub.oset(ub.peek(other.normalize().values()))
self_norm = self.normalize()
new_streams = []
for parts in self_norm.values():
new_parts = ub.oset(parts) - ub.oset(other_norm)
# shrink the representation of a complex r|g|b to an alias if
# possible.
# TODO: make this more efficient
for alias, alias_spec in self._alias_lut.items():
alias_parts = ub.oset(alias_spec.split('|'))
index = subsequence_index(new_parts, alias_parts)
if index is not None:
oset_delitem(new_parts, index)
oset_insert(new_parts, index.start, alias)
new_stream = '|'.join(new_parts)
new_streams.append(new_stream)
new_spec = ','.join(new_streams)
new = self.__class__(new_spec)
return new
[docs] def sizes(self):
"""
Number of dimensions for each fused stream channel
IE: The EARLY-FUSED channel sizes
Example:
>>> self = ChannelSpec('rgb|disparity,flowx|flowy')
>>> self.sizes()
"""
sizes = {
key: sum(self._size_lut.get(part, 1) for part in vals)
for key, vals in self.parse().items()
}
return sizes
[docs] def unique(self, normalize=False):
"""
Returns the unique channels that will need to be given or loaded
"""
import warnings
warnings.warn(
'FIXME: These kwargs are broken, but does anything use it?')
if normalize:
return set(ub.flatten(self.parse().values()))
else:
return set(ub.flatten(self.normalize().values()))
[docs] def _item_shapes(self, dims):
"""
Expected shape for an input item
Args:
dims (Tuple[int, int]): the spatial dimension
Returns:
Dict[int, tuple]
"""
item_shapes = {}
parsed = self.parse()
fused_keys = list(self.keys())
for fused_key in fused_keys:
components = parsed[fused_key]
for mode_key in components:
c = self._size_lut.get(mode_key, 1)
shape = (c,) + tuple(dims)
item_shapes[mode_key] = shape
return item_shapes
[docs] def _demo_item(self, dims=(4, 4), rng=None):
"""
Create an input that satisfies this spec
Returns:
dict: an item like it might appear when its returned from the
`__getitem__` method of a :class:`torch...Dataset`.
Example:
>>> dims = (1, 1)
>>> ChannelSpec.coerce(3)._demo_item(dims, rng=0)
>>> ChannelSpec.coerce('r|g|b|disaprity')._demo_item(dims, rng=0)
>>> ChannelSpec.coerce('rgb|disaprity')._demo_item(dims, rng=0)
>>> ChannelSpec.coerce('rgb,disaprity')._demo_item(dims, rng=0)
>>> ChannelSpec.coerce('rgb')._demo_item(dims, rng=0)
>>> ChannelSpec.coerce('gray')._demo_item(dims, rng=0)
"""
import kwarray
rng = kwarray.ensure_rng(rng)
item_shapes = self._item_shapes(dims)
item = {
key: rng.rand(*shape)
for key, shape in item_shapes.items()
}
return item
[docs] def encode(self, item, axis=0, mode=1):
"""
Given a dictionary containing preloaded components of the network
inputs, build a concatenated (fused) network representations of each
input stream.
Args:
item (Dict[str, Tensor]): a batch item containing unfused parts.
each key should be a single-stream (optionally early fused)
channel key.
axis (int, default=0): concatenation dimension
Returns:
Dict[str, Tensor]:
mapping between input stream and its early fused tensor input.
Example:
>>> from kwcoco.channel_spec import * # NOQA
>>> import numpy as np
>>> dims = (4, 4)
>>> item = {
>>> 'rgb': np.random.rand(3, *dims),
>>> 'disparity': np.random.rand(1, *dims),
>>> 'flowx': np.random.rand(1, *dims),
>>> 'flowy': np.random.rand(1, *dims),
>>> }
>>> # Complex Case
>>> self = ChannelSpec('rgb,disparity,rgb|disparity|flowx|flowy,flowx|flowy')
>>> fused = self.encode(item)
>>> input_shapes = ub.map_vals(lambda x: x.shape, fused)
>>> print('input_shapes = {}'.format(ub.repr2(input_shapes, nl=1)))
>>> # Simpler case
>>> self = ChannelSpec('rgb|disparity')
>>> fused = self.encode(item)
>>> input_shapes = ub.map_vals(lambda x: x.shape, fused)
>>> print('input_shapes = {}'.format(ub.repr2(input_shapes, nl=1)))
Example:
>>> # Case where we have to break up early fused data
>>> import numpy as np
>>> dims = (40, 40)
>>> item = {
>>> 'rgb|disparity': np.random.rand(4, *dims),
>>> 'flowx': np.random.rand(1, *dims),
>>> 'flowy': np.random.rand(1, *dims),
>>> }
>>> # Complex Case
>>> self = ChannelSpec('rgb,disparity,rgb|disparity,rgb|disparity|flowx|flowy,flowx|flowy,flowx,disparity')
>>> inputs = self.encode(item)
>>> input_shapes = ub.map_vals(lambda x: x.shape, inputs)
>>> print('input_shapes = {}'.format(ub.repr2(input_shapes, nl=1)))
>>> # xdoctest: +REQUIRES(--bench)
>>> #self = ChannelSpec('rgb|disparity,flowx|flowy')
>>> import timerit
>>> ti = timerit.Timerit(100, bestof=10, verbose=2)
>>> for timer in ti.reset('mode=simple'):
>>> with timer:
>>> inputs = self.encode(item, mode=0)
>>> for timer in ti.reset('mode=minimize-concat'):
>>> with timer:
>>> inputs = self.encode(item, mode=1)
Ignore:
import xdev
_ = xdev.profile_now(self.encode)(item, mode=1)
"""
import kwarray
if len(item) == 0:
raise ValueError('Cannot encode empty item')
_impl = kwarray.ArrayAPI.coerce(ub.peek(item.values()))
parsed = self.parse()
# unique = self.unique()
# TODO: This can be made much more efficient by determining if the
# channels item can be directly translated to the result inputs. We
# probably don't need to do the full decoding each and every time.
if mode == 1:
# Slightly more complex implementation that attempts to minimize
# concat operations.
item_keys = tuple(sorted(item.keys()))
parsed_items = tuple(sorted([(k, tuple(v)) for k, v in parsed.items()]))
new_fused_indices = _cached_single_fused_mapping(item_keys, parsed_items, axis=axis)
fused = {}
for key, idx_list in new_fused_indices.items():
parts = [item[item_key][item_sl] for item_key, item_sl in idx_list]
if len(parts) == 1:
fused[key] = parts[0]
else:
fused[key] = _impl.cat(parts, axis=axis)
elif mode == 0:
# Simple implementation that always does the full break down of
# item components.
components = {}
# Determine the layout of the channels in the input item
key_specs = {key: ChannelSpec(key) for key in item.keys()}
for key, spec in key_specs.items():
decoded = spec.decode({key: item[key]}, axis=axis)
for subkey, subval in decoded.items():
components[subkey] = subval
fused = {}
for key, parts in parsed.items():
fused[key] = _impl.cat([components[k] for k in parts], axis=axis)
else:
raise KeyError(mode)
return fused
[docs] def decode(self, inputs, axis=1):
"""
break an early fused item into its components
Args:
inputs (Dict[str, Tensor]): dictionary of components
axis (int, default=1): channel dimension
Example:
>>> from kwcoco.channel_spec import * # NOQA
>>> import numpy as np
>>> dims = (4, 4)
>>> item_components = {
>>> 'rgb': np.random.rand(3, *dims),
>>> 'ir': np.random.rand(1, *dims),
>>> }
>>> self = ChannelSpec('rgb|ir')
>>> item_encoded = self.encode(item_components)
>>> batch = {k: np.concatenate([v[None, :], v[None, :]], axis=0)
... for k, v in item_encoded.items()}
>>> components = self.decode(batch)
Example:
>>> # xdoctest: +REQUIRES(module:netharn, module:torch)
>>> import torch
>>> import numpy as np
>>> dims = (4, 4)
>>> components = {
>>> 'rgb': np.random.rand(3, *dims),
>>> 'ir': np.random.rand(1, *dims),
>>> }
>>> components = ub.map_vals(torch.from_numpy, components)
>>> self = ChannelSpec('rgb|ir')
>>> encoded = self.encode(components)
>>> from netharn.data import data_containers
>>> item = {k: data_containers.ItemContainer(v, stack=True)
>>> for k, v in encoded.items()}
>>> batch = data_containers.container_collate([item, item])
>>> components = self.decode(batch)
"""
parsed = self.parse()
components = dict()
for key, parts in parsed.items():
idx1 = 0
for part in parts:
size = self._size_lut.get(part, 1)
idx2 = idx1 + size
fused = inputs[key]
index = tuple([slice(None)] * axis + [slice(idx1, idx2)])
component = fused[index]
components[part] = component
idx1 = idx2
return components
[docs] def component_indices(self, axis=2):
"""
Look up component indices within fused streams
Example:
>>> dims = (4, 4)
>>> inputs = ['flowx', 'flowy', 'disparity']
>>> self = ChannelSpec('disparity,flowx|flowy')
>>> component_indices = self.component_indices()
>>> print('component_indices = {}'.format(ub.repr2(component_indices, nl=1)))
component_indices = {
'disparity': ('disparity', (slice(None, None, None), slice(None, None, None), slice(0, 1, None))),
'flowx': ('flowx|flowy', (slice(None, None, None), slice(None, None, None), slice(0, 1, None))),
'flowy': ('flowx|flowy', (slice(None, None, None), slice(None, None, None), slice(1, 2, None))),
}
"""
parsed = self.parse()
component_indices = dict()
for key, parts in parsed.items():
idx1 = 0
for part in parts:
size = self._size_lut.get(part, 1)
idx2 = idx1 + size
index = tuple([slice(None)] * axis + [slice(idx1, idx2)])
idx1 = idx2
component_indices[part] = (key, index)
return component_indices
@functools.lru_cache(maxsize=None)
[docs]def _cached_single_fused_mapping(item_keys, parsed_items, axis=0):
item_indices = {}
for key in item_keys:
key_idxs = _cached_single_stream_idxs(key, axis=axis)
for subkey, subsl in key_idxs.items():
item_indices[subkey] = subsl
fused_indices = {}
for key, parts in parsed_items:
fused_indices[key] = [item_indices[k] for k in parts]
new_fused_indices = {}
for key, idx_list in fused_indices.items():
# Determine which continguous slices can be merged into a
# single slice
prev_key = None
prev_sl = None
accepted = []
accum = []
for item_key, item_sl in idx_list:
if prev_key == item_key:
if prev_sl.stop == item_sl[-1].start and prev_sl.step == item_sl[-1].step:
accum.append((item_key, item_sl))
continue
if accum:
accepted.append(accum)
accum = []
prev_key = item_key
prev_sl = item_sl[-1]
accum.append((item_key, item_sl))
if accum:
accepted.append(accum)
accum = []
# Merge the accumulated contiguous slices
new_idx_list = []
for accum in accepted:
if len(accum) > 1:
item_key = accum[0][0]
first = accum[0][1]
last = accum[-1][1]
new_sl = list(first)
new_sl[-1] = slice(first[-1].start, last[-1].stop, last[-1].step)
new_sl = tuple(new_sl)
new_idx_list.append((item_key, new_sl))
else:
new_idx_list.append(accum[0])
val = new_idx_list
new_fused_indices[key] = val
return new_fused_indices
@functools.lru_cache(maxsize=None)
[docs]def _cached_single_stream_idxs(key, axis=0):
"""
Ignore:
hack for speed
axis = 0
key = 'rgb|disparity'
# xdoctest: +REQUIRES(--bench)
import timerit
ti = timerit.Timerit(100, bestof=10, verbose=2)
for timer in ti.reset('time'):
with timer:
_cached_single_stream_idxs(key, axis=axis)
for timer in ti.reset('time'):
with timer:
ChannelSpec(key).component_indices(axis=axis)
"""
# concat operations.
key_idxs = ChannelSpec(key).component_indices(axis=axis)
return key_idxs
[docs]def subsequence_index(oset1, oset2):
"""
Returns a slice into the first items indicating the position of
the second items if they exist.
This is a variant of the substring problem.
Returns:
None | slice
Example:
>>> oset1 = ub.oset([1, 2, 3, 4, 5, 6])
>>> oset2 = ub.oset([2, 3, 4])
>>> index = subsequence_index(oset1, oset2)
>>> assert index
>>> oset1 = ub.oset([1, 2, 3, 4, 5, 6])
>>> oset2 = ub.oset([2, 4, 3])
>>> index = subsequence_index(oset1, oset2)
>>> assert not index
"""
if len(oset2) == 0:
base = 0
else:
item1 = oset2[0]
try:
base = oset1.index(item1)
except (IndexError, KeyError):
base = None
index = None
if base is not None:
sl = slice(base, base + len(oset2))
subset = oset1[sl]
if subset == oset2:
index = sl
return index
[docs]def oset_insert(self, index, obj):
"""
Ignore:
self = ub.oset()
oset_insert(self, 0, 'a')
oset_insert(self, 0, 'b')
oset_insert(self, 0, 'c')
oset_insert(self, 1, 'd')
oset_insert(self, 2, 'e')
oset_insert(self, 0, 'f')
"""
if obj not in self:
# Bump index of every item after the insert position
for key in self.items[index:]:
self.map[key] = self.map[key] + 1
self.items.insert(index, obj)
self.map[obj] = index
[docs]def oset_delitem(self, index):
"""
for ubelt oset, todo contribute back to luminosoinsight
>>> self = ub.oset([1, 2, 3, 4, 5, 6, 7, 8, 9])
>>> index = slice(3, 5)
>>> oset_delitem(self, index)
Ignore:
self = ub.oset(['r', 'g', 'b', 'disparity'])
index = slice(0, 3)
oset_delitem(self, index)
"""
if isinstance(index, slice) and index == ub.orderedset.SLICE_ALL:
self.clear()
else:
if ub.orderedset.is_iterable(index):
to_remove = [self.items[i] for i in index]
elif isinstance(index, slice) or hasattr(index, "__index__"):
to_remove = self.items[index]
else:
raise TypeError("Don't know how to index an OrderedSet by %r" % index)
if isinstance(to_remove, list):
# Modified version of discard slightly more efficient for multiple
# items
remove_idxs = sorted([self.map[key] for key in to_remove], reverse=True)
for key in to_remove:
del self.map[key]
for idx in remove_idxs:
del self.items[idx]
for k, v in self.map.items():
# I think there is a more efficient way to do this?
num_after = sum(v >= i for i in remove_idxs)
if num_after:
self.map[k] = v - num_after
else:
self.discard(to_remove)