# -*- coding: utf-8 -*-
"""
The :mod:`category_tree` module defines the :class:`CategoryTree` class, which
is used for maintaining flat or hierarchical category information. The kwcoco
version of this class only contains the datastructure and does not contain any
torch operations. See the ndsampler version for the extension with torch
operations.
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import itertools as it
import networkx as nx
import ubelt as ub
import numpy as np
__all__ = ['CategoryTree']
class CategoryTree(ub.NiceRepr):
    """
    Wrapper that maintains flat or hierarchical category information.

    Helps compute softmaxes and probabilities for tree-based categories
    where a directed edge (A, B) represents that A is a superclass of B.

    Note:
        There are three basic properties that this object maintains:

        .. code::

            node:
                Alphanumeric string names that should be generally descriptive.
                Using spaces and special characters in these names is
                discouraged, but can be done.  This is the COCO category "name"
                attribute.  For categories this may be denoted as (name, node,
                cname, catname).

            id:
                The integer id of a category should ideally remain consistent.
                These are often given by a dataset (e.g. a COCO dataset).  This
                is the COCO category "id" attribute. For categories this is
                often denoted as (id, cid).

            index:
                Contiguous zero-based indices that indexes the list of
                categories.  These should be used for the fastest access in
                backend computation tasks. Typically corresponds to the
                ordering of the channels in the final linear layer in an
                associated model.  For categories this is often denoted as
                (index, cidx, idx, or cx).

    Attributes:
        idx_to_node (List[str]): a list of class names. Implicitly maps from
            index to category name.

        id_to_node (Dict[int, str]): maps integer ids to category names

        node_to_id (Dict[str, int]): maps category names to ids

        node_to_idx (Dict[str, int]): maps category names to indexes

        graph (networkx.Graph): a Graph that stores any hierarchy information.
            For standard mutually exclusive classes, this graph is edgeless.
            Nodes in this graph can maintain category attributes / properties.

        idx_groups (List[List[int]]): groups of category indices that
            share the same parent category.

    Example:
        >>> from kwcoco.category_tree import *
        >>> graph = nx.from_dict_of_lists({
        >>>     'background': [],
        >>>     'foreground': ['animal'],
        >>>     'animal': ['mammal', 'fish', 'insect', 'reptile'],
        >>>     'mammal': ['dog', 'cat', 'human', 'zebra'],
        >>>     'zebra': ['grevys', 'plains'],
        >>>     'grevys': ['fred'],
        >>>     'dog': ['boxer', 'beagle', 'golden'],
        >>>     'cat': ['maine coon', 'persian', 'sphynx'],
        >>>     'reptile': ['bearded dragon', 't-rex'],
        >>> }, nx.DiGraph)
        >>> self = CategoryTree(graph)
        >>> print(self)
        <CategoryTree(nNodes=22, maxDepth=6, maxBreadth=4...)>

    Example:
        >>> # The coerce classmethod is the easiest way to create an instance
        >>> import kwcoco
        >>> kwcoco.CategoryTree.coerce(['a', 'b', 'c'])
        <CategoryTree...nNodes=3, nodes=...'a', 'b', 'c'...
        >>> kwcoco.CategoryTree.coerce(4)
        <CategoryTree...nNodes=4, nodes=...'class_1', 'class_2', 'class_3', ...
    """

    def __init__(self, graph=None, checks=True):
        """
        Args:
            graph (nx.DiGraph):
                the graph representing a category hierarchy. If None, an
                empty directed graph is used.

            checks (bool, default=True):
                if false, bypass input checks

        Raises:
            TypeError: if ``checks`` is True and ``graph`` is not a networkx
                graph.
            ValueError: if ``checks`` is True and a non-empty ``graph`` is
                not a DAG or is not a forest.
        """
        if graph is None:
            graph = nx.DiGraph()
        elif checks:
            # Validate the type before handing the object to networkx
            # algorithms; otherwise a non-graph input fails with an unrelated
            # error and the intended TypeError is never reached.
            if not isinstance(graph, nx.Graph):
                raise TypeError('Input to CategoryTree must be a networkx graph not {}'.format(type(graph)))
            if len(graph) > 0:
                if not nx.is_directed_acyclic_graph(graph):
                    raise ValueError('The category graph must a DAG')
                if not nx.is_forest(graph):
                    raise ValueError('The category graph must be a forest')
        self.graph = graph  # :type: nx.Graph
        # Note: nodes are class names
        self.id_to_node = None
        self.node_to_id = None
        self.node_to_idx = None
        self.idx_to_node = None
        self.idx_groups = None
        self._build_index()

    def copy(self):
        """
        Create a new instance of this class backed by a copy of the graph.

        Returns:
            CategoryTree: the new instance
        """
        new = self.__class__(self.graph.copy())
        return new

    @classmethod
    def from_mutex(cls, nodes, bg_hack=True):
        """
        Create a flat (edgeless) tree from mutually exclusive class names.

        Args:
            nodes (List[str]): a list of class names (they will all be
                assumed to be mutually exclusive)

            bg_hack (bool, default=True): if True and a node named
                'background' is given, it is assigned id 0 and the ids
                assigned to the other nodes start counting from 1.

        Example:
            >>> print(CategoryTree.from_mutex(['a', 'b', 'c']))
            <CategoryTree(nNodes=3, ...)>
        """
        nodes = list(nodes)
        graph = nx.DiGraph()
        graph.add_nodes_from(nodes)
        start = 0
        if bg_hack:
            # TODO: deprecate the defaultness of this
            if 'background' in graph.nodes:
                # hack
                graph.nodes['background']['id'] = 0
                start = 1

        for i, node in enumerate(nodes, start=start):
            # Keep an id that was already assigned (e.g. the background id)
            graph.nodes[node]['id'] = graph.nodes[node].get('id', i)
        return cls(graph)

    @classmethod
    def from_json(cls, state):
        """
        Args:
            state (Dict): see __getstate__ / __json__ for details
        """
        self = cls()
        self.__setstate__(state)
        return self

    @classmethod
    def from_coco(cls, categories):
        """
        Create a CategoryTree object from coco categories

        Args:
            categories (List[Dict]): list of coco-style categories. Each
                must contain a 'name'; a non-None 'supercategory' adds an
                edge from the supercategory to the category.
        """
        graph = nx.DiGraph()
        for cat in categories:
            graph.add_node(cat['name'], **cat)
            if cat.get('supercategory', None) is not None:
                graph.add_edge(cat['supercategory'], cat['name'])
        self = cls(graph)
        return self

    @classmethod
    def coerce(cls, data, **kw):
        """
        Attempt to coerce data as a CategoryTree object.

        This is primarily useful for when the software stack depends on
        categories being represented by a CategoryTree object.

        This will work if the input data is a specially formatted json dict, a
        list of mutually exclusive classes, or if it is already a CategoryTree.
        Otherwise an error will be thrown.

        Args:
            data (object): a known representation of a category tree.
            **kw: input type specific arguments

        Returns:
            CategoryTree: self

        Raises:
            TypeError - if the input format is unknown
            ValueError - if kwargs are not compatible with the input format

        Example:
            >>> import kwcoco
            >>> classes1 = kwcoco.CategoryTree.coerce(3)  # integer
            >>> classes2 = kwcoco.CategoryTree.coerce(classes1.__json__())  # graph dict
            >>> classes3 = kwcoco.CategoryTree.coerce(['class_1', 'class_2', 'class_3'])  # mutex list
            >>> classes4 = kwcoco.CategoryTree.coerce(classes1.graph)  # nx Graph
            >>> classes5 = kwcoco.CategoryTree.coerce(classes1)  # cls
            >>> # xdoctest: +REQUIRES(module:ndsampler)
            >>> import ndsampler
            >>> classes6 = ndsampler.CategoryTree.coerce(3)
            >>> classes7 = ndsampler.CategoryTree.coerce(classes1)
            >>> classes8 = kwcoco.CategoryTree.coerce(classes6)
        """
        if isinstance(data, int):
            # An integer specifies the number of classes.
            self = cls.from_mutex(
                ['class_{}'.format(i + 1) for i in range(data)], **kw)
        elif isinstance(data, dict):
            # A dictionary is assumed to be in a special json format
            self = cls.from_json(data, **kw)
        elif isinstance(data, list):
            # A list is assumed to be a list of class names
            self = cls.from_mutex(data, **kw)
        elif isinstance(data, nx.DiGraph):
            # A nx.DiGraph should represent the category tree
            self = cls(data, **kw)
        elif isinstance(data, cls):
            # If data is already a CategoryTree, do nothing and just return it
            self = data
            if len(kw):
                raise ValueError(
                    'kwargs cannot with this cls={}, type(data)={}'.format(
                        cls, type(data)))
        elif issubclass(cls, type(data)):
            # If we are an object that inherits from kwcoco.CategoryTree (e.g.
            # ndsampler.CategoryTree), but we are given a raw
            # kwcoco.CategoryTree, we need to try and upgrade the data
            # structure.
            self = cls(data.graph)
        else:
            raise TypeError(
                'Unknown type cls={}, type(data)={}: data={!r}'.format(
                    cls, type(data), data))
        return self

    @classmethod
    def demo(cls, key='coco', **kwargs):
        """
        Args:
            key (str): specify which demo dataset to use.
                Can be 'coco' (which uses the default coco demo data).
                Can be 'btree' which creates a binary tree and accepts kwargs
                'r' and 'h' for branching-factor and height.
                Can be 'btree2', which is the same as btree but returns strings
                Can be 'animals_v1', which is a small hand-written hierarchy.

        Raises:
            KeyError: if ``key`` is not one of the known demo names

        CommandLine:
            xdoctest -m ~/code/kwcoco/kwcoco/category_tree.py CategoryTree.demo

        Example:
            >>> from kwcoco.category_tree import *
            >>> self = CategoryTree.demo()
            >>> print('self = {}'.format(self))
            self = <CategoryTree(nNodes=10, maxDepth=2, maxBreadth=4...)>
        """
        if key == 'coco':
            from kwcoco import coco_dataset
            dset = coco_dataset.CocoDataset.demo(**kwargs)
            dset.add_category('background', id=0)
            graph = dset.category_graph()
        elif key == 'btree':
            r = kwargs.pop('r', 3)
            h = kwargs.pop('h', 3)
            graph = nx.generators.balanced_tree(r=r, h=h, create_using=nx.DiGraph())
            # Shift labels by one so that 0 is free for the optional extra node
            graph = nx.relabel_nodes(graph, {n: n + 1 for n in graph})
            if kwargs.pop('add_zero', True):
                graph.add_node(0)
            assert not kwargs
        elif key == 'btree2':
            r = kwargs.pop('r', 3)
            h = kwargs.pop('h', 3)
            graph = nx.generators.balanced_tree(r=r, h=h, create_using=nx.DiGraph())
            graph = nx.relabel_nodes(graph, {n: str(n + 1) for n in graph})
            if kwargs.pop('add_zero', True):
                graph.add_node(str(0))
            assert not kwargs
        elif key == 'animals_v1':
            graph = nx.from_dict_of_lists({
                'background': [],
                'foreground': ['animal'],
                'animal': ['mammal', 'fish', 'insect', 'reptile'],
                'mammal': ['dog', 'cat', 'human', 'zebra'],
                'zebra': ['grevys', 'plains'],
                'grevys': ['fred'],
                'dog': ['boxer', 'beagle', 'golden'],
                'cat': ['maine coon', 'persian', 'sphynx'],
                'reptile': ['bearded dragon', 't-rex'],
            }, nx.DiGraph)
        else:
            raise KeyError(key)
        self = cls(graph)
        return self

    def to_coco(self):
        """
        Converts to a coco-style data structure

        Yields:
            Dict: coco category dictionaries with 'id', 'name' and, when the
                node has exactly one parent, 'supercategory'.

        Raises:
            Exception: if a node has more than one parent
        """
        for cid, node in self.id_to_node.items():
            cat = {
                'id': cid,
                'name': node,
            }
            parents = list(self.graph.predecessors(node))
            if len(parents) == 1:
                cat['supercategory'] = parents[0]
            else:
                if len(parents) > 1:
                    raise Exception('not a tree')
            yield cat

    @ub.memoize_property
    def id_to_idx(self):
        """
        Mapping from category ids to category indexes.

        Example:
            >>> import kwcoco
            >>> self = kwcoco.CategoryTree.demo()
            >>> self.id_to_idx[1]
        """
        return _calldict({cid: self.node_to_idx[node]
                          for cid, node in self.id_to_node.items()})

    @ub.memoize_property
    def idx_to_id(self):
        """
        List of category ids, ordered by category index.

        Example:
            >>> import kwcoco
            >>> self = kwcoco.CategoryTree.demo()
            >>> self.idx_to_id[0]
        """
        return [self.node_to_id[node]
                for node in self.idx_to_node]

    @ub.memoize_method
    def idx_to_ancestor_idxs(self, include_self=True):
        """
        Mapping from a class index to its ancestors

        Args:
            include_self (bool, default=True):
                if True includes each node as its own ancestor.

        Returns:
            Dict[int, Set[int]]: maps each index to the indexes of its
                ancestors
        """
        lut = {
            idx: set(ub.take(self.node_to_idx, nx.ancestors(self.graph, node)))
            for idx, node in enumerate(self.idx_to_node)
        }
        if include_self:
            for idx, idxs in lut.items():
                idxs.add(idx)
        return lut

    @ub.memoize_method
    def idx_to_descendants_idxs(self, include_self=False):
        """
        Mapping from a class index to its descendants

        Args:
            include_self (bool, default=False):
                if True includes each node as its own descendant.

        Returns:
            Dict[int, Set[int]]: maps each index to the indexes of its
                descendants
        """
        lut = {
            idx: set(ub.take(self.node_to_idx, nx.descendants(self.graph, node)))
            for idx, node in enumerate(self.idx_to_node)
        }
        if include_self:
            for idx, idxs in lut.items():
                idxs.add(idx)
        return lut

    @ub.memoize_method
    def idx_pairwise_distance(self):
        """
        Get a matrix encoding the distance from one class to another.

        Distances
            * from parents to children are positive (descendants),
            * from children to parents are negative (ancestors),
            * between unreachable nodes (wrt to forward and reverse graph) are nan.

        Returns:
            ndarray: float32 array of shape (len(self), len(self)) indexed by
                category index
        """
        pdist = np.full((len(self), len(self)), fill_value=np.nan,
                        dtype=np.float32)
        for node1, dists in nx.all_pairs_shortest_path_length(self.graph):
            idx1 = self.node_to_idx[node1]
            for node2, dist in dists.items():
                idx2 = self.node_to_idx[node2]
                # Forward direction is positive, the reverse is its negation
                pdist[idx1, idx2] = dist
                pdist[idx2, idx1] = -dist
        return pdist

    def __len__(self):
        return len(self.graph)

    def __iter__(self):
        return iter(self.idx_to_node)

    def __getitem__(self, index):
        return self.idx_to_node[index]

    def __contains__(self, node):
        # Use the dict lookup instead of scanning the idx_to_node list.
        try:
            return node in self.node_to_idx
        except TypeError:
            # Unhashable objects cannot be looked up in the dict; the previous
            # list based membership test reported False for them.
            return False

    def __json__(self):
        """
        Example:
            >>> import pickle
            >>> self = CategoryTree.demo()
            >>> print('self = {!r}'.format(self.__json__()))
        """
        return self.__getstate__()

    def __getstate__(self):
        """
        Serializes information in this class

        Returns:
            Dict: a copy of the instance dictionary where the graph is
                replaced by its nested tuple encoding and the entries that
                ``__setstate__`` can reconstruct are dropped.

        Example:
            >>> from kwcoco.category_tree import *
            >>> import pickle
            >>> self = CategoryTree.demo()
            >>> state = self.__getstate__()
            >>> serialization = pickle.dumps(self)
            >>> recon = pickle.loads(serialization)
            >>> assert recon.__json__() == self.__json__()
        """
        state = self.__dict__.copy()
        for key in list(state.keys()):
            if key.startswith('_cache'):
                state.pop(key)
        state['graph'] = to_directed_nested_tuples(self.graph)
        # Remove redundant items (rebuilt in __setstate__)
        state.pop('node_to_idx')
        state.pop('node_to_id')
        state.pop('idx_groups')
        return state

    def __setstate__(self, state):
        """
        Restore this object from the output of ``__getstate__``.

        Args:
            state (Dict): must contain 'graph'. If 'node_to_idx' or
                'node_to_id' are missing they are rebuilt from 'idx_to_node'
                and 'id_to_node'. The given dictionary is not modified.
        """
        # Work on a shallow copy so the dictionary owned by the caller (e.g.
        # the one passed to from_json / coerce) is left untouched.
        state = dict(state)
        graph = from_directed_nested_tuples(state['graph'])
        need_reindex = False
        # Reconstruct redundant items
        if 'node_to_idx' not in state:
            state['node_to_idx'] = {node: idx for idx, node in
                                    enumerate(state['idx_to_node'])}
        if 'node_to_id' not in state:
            state['node_to_id'] = {node: cid for cid, node in
                                   state['id_to_node'].items()}
        if 'idx_groups' not in state:
            node_groups = list(traverse_siblings(graph))
            node_to_idx = state['node_to_idx']
            try:
                state['idx_groups'] = [sorted([node_to_idx[n] for n in group])
                                       for group in node_groups]
            except KeyError:
                # The graph has nodes the lookup tables do not know about, so
                # rebuild every table from the graph below.
                need_reindex = True

        self.__dict__.update(state)
        self.graph = graph
        if need_reindex:
            self._build_index()

    def __nice__(self):
        max_depth = tree_depth(self.graph)
        if max_depth > 1:
            max_breadth = max(it.chain([0], map(len, self.idx_groups)))
            text = 'nNodes={}, maxDepth={}, maxBreadth={}, nodes={}'.format(
                self.num_classes, max_depth, max_breadth, self.idx_to_node,
            )
        else:
            text = 'nNodes={}, nodes={}'.format(
                self.num_classes, self.idx_to_node,
            )
        return text

    def is_mutex(self):
        """
        Returns True if all categories are mutually exclusive (i.e. flat)

        If true, then the classes may be represented as a simple list of class
        names without any loss of information, otherwise the underlying
        category graph is necessary to preserve all knowledge.

        TODO:
            - [ ] what happens when we have a dummy root?
        """
        return len(self.graph.edges) == 0

    @property
    def num_classes(self):
        return self.graph.number_of_nodes()

    @property
    def class_names(self):
        return self.idx_to_node

    @property
    def category_names(self):
        return self.idx_to_node

    @property
    def cats(self):
        """
        Returns a mapping from category names to category attributes.

        If this category tree was constructed from a coco-dataset, then this
        will contain the coco category attributes.

        Returns:
            Dict[str, Dict[str, object]]

        Example:
            >>> from kwcoco.category_tree import *
            >>> self = CategoryTree.demo()
            >>> print('self.cats = {!r}'.format(self.cats))
        """
        return dict(self.graph.nodes)

    def index(self, node):
        """ Return the index that corresponds to the category name """
        return self.node_to_idx[node]

    def _build_index(self):
        """ construct lookup tables """
        # Most of the categories should have been given integer ids
        max_id = max(it.chain([0], nx.get_node_attributes(self.graph, 'id').values()))
        # Fill in id-values for any node that doesn't have one
        node_to_id = {}
        for node, attrs in sorted(self.graph.nodes.items()):
            node_to_id[node] = attrs.get('id', max_id + 1)
            max_id = max(max_id, node_to_id[node])
        id_to_node = ub.invert_dict(node_to_id)

        # Compress ids into a flat index space (sorted by node ids)
        idx_to_node = ub.argsort(node_to_id)
        node_to_idx = {node: idx for idx, node in enumerate(idx_to_node)}

        # Find the sets of nodes that need to be softmax-ed together
        node_groups = list(traverse_siblings(self.graph))
        idx_groups = [sorted([node_to_idx[n] for n in group])
                      for group in node_groups]

        # Set instance attributes
        self.id_to_node = id_to_node
        self.node_to_id = node_to_id
        self.idx_to_node = idx_to_node
        self.node_to_idx = node_to_idx
        self.idx_groups = idx_groups

    def show(self):
        """
        Draw the category graph with networkx.

        Ignore:
            >>> import kwplot
            >>> kwplot.autompl()
            >>> from kwcoco import category_tree
            >>> self = category_tree.CategoryTree.demo()
            >>> self.show()

            python -c "import kwplot, kwcoco, graphid; kwplot.autompl(); graphid.util.show_nx(kwcoco.category_tree.CategoryTree.demo().graph); kwplot.show_if_requested()" --show
        """
        try:
            pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, prog='dot')
        except ImportError:
            import warnings
            warnings.warn('pygraphviz is not available')
            # Fall back to the default networkx layout
            pos = None
        nx.draw_networkx(self.graph, pos=pos)

    def forest_str(self):
        """
        Build a text rendering of the category hierarchy.

        Returns:
            str: multi-line text representation of the graph
        """
        # ``nx.forest_str`` is not available in every networkx release, so
        # fall back to ``nx.generate_network_text`` when it is missing.
        legacy_forest_str = getattr(nx, 'forest_str', None)
        if legacy_forest_str is not None:
            return legacy_forest_str(self.graph)
        return '\n'.join(nx.generate_network_text(self.graph))

    def normalize(self):
        """
        Applies a normalization scheme to the categories.

        Names are lower-cased and spaces are removed. The 'name' and
        'supercategory' node attributes are normalized the same way, and
        every node in the result has an explicit 'id'.

        Note: this may break other tasks that depend on exact category names.

        Returns:
            CategoryTree

        Example:
            >>> from kwcoco.category_tree import *  # NOQA
            >>> import kwcoco
            >>> orig = kwcoco.CategoryTree.demo('animals_v1')
            >>> self = kwcoco.CategoryTree(nx.relabel_nodes(orig.graph, str.upper))
            >>> norm = self.normalize()
        """
        def normalize_name(name):
            return name.lower().replace(' ', '')

        new_graph = self.graph.__class__()
        node_mapping = {}
        for old_node, old_data in self.graph.nodes(data=True):
            new_node = normalize_name(old_node)
            new_data = old_data.copy()
            if 'id' not in old_data:
                # Freeze the id that was assigned automatically
                new_data['id'] = self.node_to_id[old_node]
            if 'supercategory' in old_data:
                new_data['supercategory'] = normalize_name(old_data['supercategory'])
            if 'name' in old_data:
                new_data['name'] = normalize_name(old_data['name'])
            node_mapping[old_node] = new_node
            new_graph.add_node(new_node, **new_data)

        for old_u, old_v, old_data in self.graph.edges(data=True):
            new_u = node_mapping[old_u]
            new_v = node_mapping[old_v]
            new_data = old_data.copy()
            new_graph.add_edge(new_u, new_v, **new_data)

        new = self.__class__(new_graph)
        return new
def source_nodes(graph):
    """
    Lazily produce the roots of a directed graph.

    Args:
        graph: object providing ``nodes()`` and ``in_degree(node)``

    Returns:
        Generator: the nodes whose in-degree is zero, in ``graph.nodes()``
            order
    """
    return (
        candidate
        for candidate in graph.nodes()
        if graph.in_degree(candidate) == 0
    )
def sink_nodes(graph):
    """
    Generates sink nodes --- nodes without outgoing edges.

    Args:
        graph: object providing ``nodes()`` and ``out_degree(node)``

    Returns:
        Generator: the nodes whose out-degree is zero, in ``graph.nodes()``
            order
    """
    return (n for n in graph.nodes() if graph.out_degree(n) == 0)
def traverse_siblings(graph, sources=None):
    """
    Yield groups of nodes that share the same parent.

    The first group is ``sources`` itself (the source nodes of the graph
    when not given). Afterwards the non-empty child group of every visited
    node is yielded depth-first, in the order the nodes are encountered.

    Args:
        graph: object providing ``successors(node)``
        sources (List, default=None): nodes to start from

    Yields:
        List: a group of sibling nodes
    """
    if sources is None:
        sources = list(source_nodes(graph))
    yield sources

    # Explicit stack of partially consumed sibling groups
    pending = [iter(sources)]
    while pending:
        for parent in pending[-1]:
            offspring = list(graph.successors(parent))
            if offspring:
                yield offspring
                # Descend into this group before resuming the current one
                pending.append(iter(offspring))
                break
        else:
            # Current group is exhausted; resume its parent group
            pending.pop()
def tree_depth(graph, root=None):
    """
    Maximum depth of the forest / tree

    Args:
        graph (nx.DiGraph): a forest
        root (default=None): if given, only the tree below (and including)
            this node is measured, otherwise all source nodes are used.

    Returns:
        int: number of nodes on the longest root-to-leaf path; 0 when the
            graph is empty

    Example:
        >>> from kwcoco.category_tree import *
        >>> graph = nx.balanced_tree(r=2, h=3, create_using=nx.DiGraph)
        >>> tree_depth(graph)
        4
        >>> tree_depth(nx.balanced_tree(r=2, h=0, create_using=nx.DiGraph))
        1
    """
    if len(graph) == 0:
        return 0
    if root is not None:
        assert root in graph.nodes
    assert nx.is_forest(graph)

    def _height(node):
        # A node counts for one level plus the tallest of its subtrees
        below = (_height(child) for child in graph.successors(node))
        return max(it.chain([0], below)) + 1

    if root is None:
        return max(it.chain([0], (_height(n) for n in source_nodes(graph))))
    return _height(root)
def to_directed_nested_tuples(graph, with_data=True):
    """
    Serialize a networkx graph.

    Each node is encoded together with its (sorted) children as a tuple
    ``(node, children)``, or ``(node, node_data, children)`` when
    ``with_data`` is True, where ``children`` is a list of such tuples.

    Args:
        graph (nx.DiGraph): graph to encode
        with_data (bool, default=True): if True, include the node
            attribute dictionary in each tuple

    Returns:
        List[Tuple]: encodings of the sorted source nodes
    """
    def _encode(node):
        if with_data:
            attrs = graph.nodes[node]
            return (node, attrs, _encode_all(graph.successors(node)))
        return (node, _encode_all(graph.successors(node)))

    def _encode_all(nodes):
        # Sorting makes the encoding independent of insertion order
        return [_encode(item) for item in sorted(nodes)]

    return _encode_all(source_nodes(graph))
def from_directed_nested_tuples(encoding):
    """
    Unserialize a networkx graph.

    Args:
        encoding (List[Tuple]): nested ``(node, children)`` or
            ``(node, node_data, children)`` tuples, as produced by
            :func:`to_directed_nested_tuples`.

    Returns:
        nx.DiGraph: the reconstructed graph

    Raises:
        AssertionError: if an item does not have 2 or 3 entries

    Example:
        >>> from kwcoco.category_tree import *
        >>> graph = nx.generators.gnr_graph(20, 0.3, seed=790).reverse()
        >>> graph.nodes[0]['color'] = 'black'
        >>> encoding = to_directed_nested_tuples(graph)
        >>> recon = from_directed_nested_tuples(encoding)
        >>> recon_encoding = to_directed_nested_tuples(recon)
        >>> assert recon_encoding == encoding
    """
    all_nodes = []
    all_edges = []
    attrs_of = {}

    def _collect(items):
        # Pre-order walk that accumulates into the enclosing lists
        for item in items:
            arity = len(item)
            if arity == 2:
                parent, subtree = item
            elif arity == 3:
                parent, parent_attrs, subtree = item
                attrs_of[parent] = parent_attrs
            else:
                raise AssertionError('invalid tup')
            child_edges = [(parent, sub[0]) for sub in subtree]
            all_nodes.append(parent)
            all_edges.extend(child_edges)
            _collect(subtree)

    _collect(encoding)

    graph = nx.DiGraph()
    graph.add_nodes_from(all_nodes)
    graph.add_edges_from(all_edges)
    for node, attrs in attrs_of.items():
        graph.nodes[node].update(attrs)
    return graph
class _calldict(dict):
"""
helper object to maintain backwards compatibility between new and old
id_to_idx methods.
Example:
>>> self = _calldict({1: 2})
>>> #assert self()[1] == 2
>>> assert self[1] == 2
"""
def __call__(self):
import warnings
warnings.warn('Calling id_to_idx as a method has been depricated. '
'Use this dict as a property')
return self
if __name__ == '__main__':
    # Run the doctests in this module.
    #
    # CommandLine:
    #     xdoctest -m kwcoco.category_tree
    import xdoctest
    xdoctest.doctest_module(__file__)