"""arthropod_describer.common.label_hierarchy: hierarchies of integer labels encoded as bit fields."""

import functools
import json
import math
import operator
from pathlib import Path
from typing import List, Optional, Dict, Union, Any, Tuple, Set


class Node:
    """A single label of a label hierarchy together with its tree links and display data."""

    def __init__(self):
        # Tree links; the owning hierarchy wires these up.
        self.parent: Optional[Node] = None
        self.children: List[Node] = []
        # Integer label, its textual code (e.g. 1:2:0:0) and a human-readable name.
        self.label: int = -1
        self.code: str = ''
        self.name: str = ''
        # RGB display color; plain white is a placeholder default.
        self.color: Tuple[int, int, int] = (255, 255, 255)

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the node's name, code and color into a JSON-friendly dict.

        The tree links and the integer label are not part of the result.
        """
        red, green, blue = self.color[0], self.color[1], self.color[2]
        color_dict = {'red': red, 'green': green, 'blue': blue}
        return {'name': self.name, 'code': self.code, 'color': color_dict}
class LabelHierarchy:
    """Hierarchy of integer labels whose levels are stored in disjoint bit fields.

    Every label is an ``n_bits``-wide integer that ``masks`` split into one
    contiguous bit field per level. Level 0 owns the most significant field,
    and each deeper level owns the next field towards bit 0. A label's textual
    code lists the field values from level 0 downwards, joined by ``sep``
    (e.g. ``1:1:2:0``). The label ``ROOT`` (-1) is an invisible root node that
    parents every top-level label.
    """

    ROOT: int = -1

    def __init__(self):
        # Bit masks of all levels, most significant first; len(self.masks) == number of levels.
        self.masks: List[int] = []
        # Total number of bits of a label.
        self.n_bits: int = 0
        # mask_name -> mask, built from self.mask_names and self.masks.
        self.named_masks: Dict[str, int] = {}
        # Human-friendly level names (e.g. Animal, Segments), indexed by level.
        self.mask_names: List[str] = []
        # mask -> mask_name, the reverse of self.named_masks.
        self.masks_names: Dict[int, str] = {}
        # self.counts_of_bits[i] is the number of bits allocated to level i.
        self.counts_of_bits: List[int] = []
        # self.shifts[i] is the right shift that moves level i's field down to bit 0,
        # i.e. n_bits - sum(counts_of_bits[:i + 1]).
        self.shifts: List[int] = []
        # Set to 2**n_bits - 1 by `create`.
        self.whole_mask: int = 0
        # Separator of the fields in a label's textual code, e.g. 1:1:2:0.
        self.sep: str = ':'
        # Labels defined for this hierarchy, kept sorted ascending.
        self.labels: List[int] = []
        # label -> list of the label's children.
        self.children: Dict[int, List[int]] = {}
        # label -> the label's parent.
        self.parents: Dict[int, int] = {}
        # label -> the label's node representation.
        self.nodes: Dict[int, 'Node'] = {}
        # label -> RGB color.
        self.colormap: Dict[int, Tuple[int, int, int]] = {}
        self.name: str = ''
        # One set of labels per level; see `level_groups`.
        self._level_groups: List[Set[int]] = []
        # Node of the label stored under 'constraint_mask_label' in the serialized form (may be None).
        self.mask_label: Optional['Node'] = None

    # ------------------------------------------------------------------ masks

    @classmethod
    def are_valid_masks(cls, masks: List[int], n_bits: int = 32) -> bool:
        """Checks whether `masks` is a valid distribution of masks for `n_bits`-bit labels.

        For `masks` to be a valid list of masks:
          1. each mask must have exactly one run of ones, e.g. 11110000 is valid, 11110011 is not;
          2. no two distinct masks may overlap, e.g. 11110000, 00001111 is valid;
             11110000, 00111111 is not;
          3. neighbouring masks (in sorted order) must have adjacent runs of ones,
             e.g. 11110000, 00001111 is valid; 11110000, 00000111 is not;
          4. the bitwise union of all masks must equal `2**n_bits - 1`.
        """
        if not all(cls.has_unique_sequence_of_ones(mask, n_bits=n_bits) for mask in masks):
            return False
        return (not cls.masks_overlap(masks)
                and cls.masks_are_adjacent(masks)
                and cls.masks_cover_nbits(masks, n_bits))

    @classmethod
    def masks_overlap(cls, masks: List[int]) -> bool:
        """Checks whether any two masks in `masks` have a nonempty bitwise intersection.

        Every mask is compared against the union of the masks before it, so the
        check is exact for arbitrary masks and does not need sorted input.
        """
        union = 0
        for mask in masks:
            if union & mask:
                return True
            union |= mask
        return False

    @classmethod
    def masks_are_adjacent(cls, masks: List[int]) -> bool:
        """Checks whether neighbouring masks in sorted(`masks`) have adjacent one-bit runs.

        See rule 3 in the docstring of `are_valid_masks`. An empty or single-element
        list is considered adjacent.
        """
        ordered = sorted(masks)
        # The lowest bit of the higher mask, shifted down by one, must hit the lower mask.
        return all(lower & (higher >> 1) for lower, higher in zip(ordered, ordered[1:]))

    @classmethod
    def has_unique_sequence_of_ones(cls, mask: int, n_bits: int = 32) -> bool:
        """Checks whether the low `n_bits` bits of `mask` contain at most one contiguous run of 1-s.

        A mask without any 1-bits is accepted.
        """
        if n_bits <= 0:
            return True
        bits = mask & ((1 << n_bits) - 1)  # only the low n_bits bits are examined
        if bits == 0:
            return True
        bits >>= (bits & -bits).bit_length() - 1  # drop the trailing zeros
        # A single run of ones plus one is a power of two, so it shares no bit with itself.
        return bits & (bits + 1) == 0

    @classmethod
    def masks_cover_nbits(cls, masks: List[int], n_bits: int = 32) -> bool:
        """Checks whether the bitwise union of `masks` equals `2**n_bits - 1`.

        An empty `masks` list covers nothing, so it only passes for `n_bits == 0`.
        """
        return functools.reduce(operator.or_, masks, 0) == 2 ** n_bits - 1

    @classmethod
    def _bit_count(cls, mask: int) -> int:
        """Returns the number of 1-bits in `mask` (0 for a zero mask)."""
        return bin(mask).count('1')

    # ----------------------------------------------------------- construction

    @classmethod
    def create(cls, masks: List[int], mask_names: Optional[List[str]] = None, n_bits: int = 32) -> \
            Optional['LabelHierarchy']:
        """Creates a new label hierarchy from `masks` with `n_bits`-bit labels.

        The masks are ordered by descending value, so level 0 is the most
        significant field. `mask_names[i]` names level `i` in that order; it
        defaults to 'level i'.

        :return: the new hierarchy, or None if `masks` is not a valid distribution
            (see `are_valid_masks`).
        """
        if not cls.are_valid_masks(masks, n_bits):
            return None  # TODO raise Exception?
        if mask_names is None:
            mask_names = [f'level {i}' for i in range(len(masks))]
        hier = cls()
        hier.whole_mask = 2 ** n_bits - 1
        hier.n_bits = n_bits
        hier.masks = sorted(masks, reverse=True)
        hier.counts_of_bits = [cls._bit_count(mask) for mask in hier.masks]
        remaining = n_bits
        for bit_count in hier.counts_of_bits:
            remaining -= bit_count
            hier.shifts.append(remaining)
        hier.mask_names = list(mask_names)
        hier.named_masks = dict(zip(hier.mask_names, hier.masks))
        hier.masks_names = {mask: name for name, mask in hier.named_masks.items()}
        hier._level_groups = [set() for _ in hier.masks]
        return hier

    @classmethod
    def designate_bits(cls, counts_of_bits: List[int], mask_names: Optional[List[str]] = None, n_bits: int = 32) -> \
            Optional['LabelHierarchy']:
        """Builds a hierarchy whose level `i` gets `counts_of_bits[i]` bits, starting at the most significant bit.

        `mask_names` optionally names the levels. Returns None when the counts do
        not form a valid distribution of `n_bits` bits (see `create`).
        """
        if mask_names is None:
            mask_names = [f'level {i}' for i in range(len(counts_of_bits))]
        masks = []
        start = n_bits
        for count_of_bits in counts_of_bits:
            start -= count_of_bits
            masks.append((2 ** count_of_bits - 1) << start)
        return cls.create(masks, mask_names=mask_names, n_bits=n_bits)

    @classmethod
    def load(cls, path_to_json: Path) -> Optional['LabelHierarchy']:
        """Loads a hierarchy from the JSON file at `path_to_json` (see `from_dict`).

        Returns None if the file does not exist.
        """
        path = Path(path_to_json)
        if not path.exists():
            return None
        with open(path, encoding='utf-8') as f:
            return cls.from_dict(json.load(f))

    def save(self, path_to_json: Path):
        """Writes the `to_dict` representation of this hierarchy as JSON, so that `load` can read it back."""
        with open(path_to_json, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f)

    @classmethod
    def from_dict(cls, label_hier_info: Dict[str, Any]) -> 'LabelHierarchy':
        """Builds a hierarchy from the dict produced by `to_dict`.

        'counts_of_bits' and 'mask_names' are required. 'labels', 'name' and
        'constraint_mask_label' are optional, so files written by older versions
        of `save` still load.

        :raises ValueError: if 'counts_of_bits' does not describe a valid bit layout.
        """
        counts_of_bits = label_hier_info['counts_of_bits']
        mask_names = label_hier_info['mask_names']
        lab_hier = cls.designate_bits(counts_of_bits, mask_names=mask_names, n_bits=sum(counts_of_bits))
        if lab_hier is None:
            raise ValueError(f'invalid counts_of_bits: {counts_of_bits}')
        for label, label_dict in label_hier_info.get('labels', {}).items():
            node = Node()
            node.label = int(label)  # JSON object keys are strings
            node.code = label_dict['code']
            node.name = label_dict['name']
            color = label_dict['color']
            node.color = (int(color['red']), int(color['green']), int(color['blue']))
            lab_hier.colormap[node.label] = node.color
            lab_hier.nodes[node.label] = node
            lab_hier.labels.append(node.label)
            lab_hier._add_to_level_group(node.label)
        lab_hier.labels.sort()
        lab_hier.compute_children()
        lab_hier.name = label_hier_info.get('name', '')
        mask_label = label_hier_info.get('constraint_mask_label')
        if mask_label is not None:
            lab_hier.mask_label = lab_hier.nodes[mask_label]
        return lab_hier

    def to_dict(self) -> Dict[str, Any]:
        """Serializes the hierarchy (bit layout, name, constraint label and all nodes) into a JSON-friendly dict."""
        return {
            'name': self.name,
            'counts_of_bits': self.counts_of_bits,
            'mask_names': self.mask_names,
            'constraint_mask_label': None if self.mask_label is None else self.mask_label.label,
            'labels': {label: node.to_dict() for label, node in self.nodes.items()},
        }

    # ---------------------------------------------------------------- levels

    def get_level_name(self, label_or_code: Union[int, str]) -> str:
        """Returns the name of the level that the label given by `label_or_code` belongs to.

        :param label_or_code: either the integer representation of the label or its textual code.
        :type label_or_code: int or str
        :return: name of the level the input label belongs to
        :rtype: str
        """
        label = self.label(label_or_code) if isinstance(label_or_code, str) else label_or_code
        return self.mask_names[self.get_level(label)]

    def get_level(self, label: int) -> Optional[int]:
        """Returns the level that `label` belongs to in this hierarchy.

        The level is the deepest level whose field is nonzero. Label 0 is on
        level 0, and negative labels (the root) get -1. Returns None if `label`
        has no bit inside any mask.
        """
        if label == 0:
            return 0
        if label < 0:
            return -1
        for level in reversed(range(len(self.masks))):
            if label & self.masks[level]:
                return level
        return None

    def get_mask(self, label: int) -> int:
        """Returns the mask of the level that `label` belongs to."""
        return self.masks[self.get_level(label)]

    def level_mask(self, level: int) -> int:
        """Returns the union of self.masks[0], ..., self.masks[level].

        :param level: Number indicating the level a level mask should be returned for.
        :type level: int
        :return: The mask for the level (0 for a negative level).
        :rtype: int
        """
        mask = 0
        for lvl in range(level + 1):
            mask |= self.masks[lvl]
        return mask

    def label_mask(self, label: int) -> int:
        """Returns the union of self.masks[0], ..., self.masks[self.get_level(label)]."""
        return self.level_mask(self.get_level(label))

    def get_label_mask_up_to(self, level: int, label: int) -> int:
        """Returns the bitwise AND of `level_mask(level)` and `label`."""
        return self.level_mask(level) & label

    @property
    def level_groups(self) -> List[Set[int]]:
        """Returns sets of labels grouped by level, so `level_groups[1]` holds the labels on level 1."""
        return self._level_groups

    def group_by_level(self, labels: Union[List[int], Set[int]]) -> Dict[int, Set[int]]:
        """Groups `labels` by their level in the hierarchy."""
        groups: Dict[int, Set[int]] = {}
        for label in labels:
            groups.setdefault(self.get_level(label), set()).add(label)
        return groups

    def _add_to_level_group(self, label: int):
        """Records `label` in `_level_groups`; the root and labels outside every level are skipped."""
        if label < 0:
            return
        level = self.get_level(label)
        if level is not None and level < len(self._level_groups):
            self._level_groups[level].add(label)

    # ---------------------------------------------------------- codes/labels

    def code(self, label: int) -> str:
        """Returns the textual code of `label`: its per-level field values joined by `sep`."""
        return self.sep.join(str((label & mask) >> shift) for mask, shift in zip(self.masks, self.shifts))

    def label(self, code: str) -> int:
        """Returns the label represented by the textual `code` (inverse of `code`)."""
        fields = code.split(self.sep)
        return functools.reduce(operator.or_,
                                (int(field) << shift for field, shift in zip(fields, self.shifts)), 0)

    # -------------------------------------------------------------- relations

    def get_parent(self, label_or_code: Union[int, str]) -> int:
        """Returns the parent label of the label represented by `label_or_code`.

        Registered labels answer from `self.parents`. For a label that is not
        registered yet, the parent is derived from the bits: the label with its
        own level's field cleared, or `ROOT` for a top-level label.
        """
        label = self.label(label_or_code) if isinstance(label_or_code, str) else label_or_code
        if label in self.parents:
            return self.parents[label]
        level = self.get_level(label)
        if level <= 0:
            return self.ROOT
        return label & self.level_mask(level - 1)

    def get_ancestors(self, label: int) -> List[int]:
        """Returns the ancestors of `label`, nearest first.

        The walk stops at the first parent that is not positive, so the root (-1)
        and label 0 are never included.
        """
        ancestors = []
        curr_label = label
        while (parent := self.parents[curr_label]) > 0:
            ancestors.append(parent)
            curr_label = parent
        return ancestors

    def is_descendant_of(self, desc: int, ance: int) -> bool:
        """Checks whether `desc` has `ance` as a bit prefix, i.e. lies in the subtree of `ance`.

        A label counts as a descendant of itself.
        """
        return (desc & self.label_mask(ance)) == ance

    def is_ancestor_of(self, ance: int, label: int) -> bool:
        """Checks whether `ance` is an ancestor of `label` (see `is_descendant_of`)."""
        return self.is_descendant_of(label, ance)

    def set_labels(self, labels: List[int]):
        """Replaces the defined labels with `labels` (sorted) and rebuilds nodes, level groups and relations."""
        self.labels = sorted(labels)  # compute_children relies on ascending order
        self.nodes.clear()
        for group in self._level_groups:
            group.clear()
        for label in self.labels:
            node = Node()
            node.label = label
            node.code = self.code(label)
            self.nodes[label] = node
            self._add_to_level_group(label)
        self.compute_children()

    def compute_children(self):
        """Rebuilds the parent/child relations of all labels in `self.labels`.

        `self.labels` must be sorted ascending. Because level 0 occupies the most
        significant bits, that order lists every label right after its ancestors.
        Each label is attached to the closest preceding label of a smaller depth,
        or to the root. Previously derived relations are reset, so the method can
        be called again safely.
        """
        self._create_root_node()
        if self.ROOT in self.labels:
            self.labels.remove(self.ROOT)
        self.children.clear()
        self.parents.clear()
        for node in self.nodes.values():
            node.children = []
        self.nodes[self.ROOT].parent = None
        self.children[self.ROOT] = []
        self.parents[self.ROOT] = self.ROOT
        # (label, depth) pairs of the ancestor chain of the label being placed; the root sits at depth -1.
        stack: List[Tuple[int, int]] = [(self.ROOT, -1)]
        for curr_label in self.labels:
            depth = self.get_level(curr_label)
            while len(stack) > 1 and stack[-1][1] >= depth:
                stack.pop()
            parent_label = stack[-1][0]
            self.children.setdefault(curr_label, [])
            self.children[parent_label].append(curr_label)
            self.parents[curr_label] = parent_label
            node = self.nodes[curr_label]
            parent_node = self.nodes[parent_label]
            node.parent = parent_node
            parent_node.children.append(node)
            stack.append((curr_label, depth))

    def _create_root_node(self):
        """Creates a root node. This label node is not used by the user and should not appear anywhere visible."""
        if -1 in self.nodes:
            return
        root = Node()
        root.label = -1
        root.code = '-1:0:0:0'  # whatever, this node won't be ever used for anything useful, just to have a proper rooted tree
        root.name = 'invalid'
        root.children = []
        root.parent = None
        self.nodes[-1] = root
        self.children[-1] = []
        self.parents[-1] = -1

    # ------------------------------------------------------------- mutation

    def get_available_label(self, parent: int) -> int:
        """Returns the next free child label of `parent` (which may be `ROOT`).

        If `parent` has children, the result is the child after the largest one,
        on that child's level. Otherwise it is the first label on the level just
        below `parent`.

        :raises ValueError: if the child field of that level is already exhausted.
        """
        children = self.children.get(parent, [])
        base = max(parent, 0)  # the root contributes no prefix bits
        if children:
            last_child = max(children)
            level = self.get_level(last_child)
        else:
            last_child = base
            level = self.get_level(parent) + 1
        mask = self.masks[level]
        shift = self.shifts[level]
        index = ((last_child & mask) >> shift) + 1
        if index > (mask >> shift):
            raise ValueError(f'no free child label left for parent {parent} on level {level}')
        return base | (index << shift)

    def _insert_label(self, label: int, name: str, color: Tuple[int, int, int],
                      parent_label: int, parent_node: Optional['Node']) -> 'Node':
        """Creates the node for `label`, links it under `parent_node` (if any) and registers it everywhere."""
        label_node = Node()
        label_node.label = label
        label_node.parent = parent_node
        label_node.name = name
        label_node.code = self.code(label)
        label_node.children = []
        label_node.color = color
        self.children[label] = []
        self.parents[label] = parent_label
        if parent_node is not None:
            self.children.setdefault(parent_label, []).append(label)
            parent_node.children.append(label_node)
        self.nodes[label] = label_node
        self.labels.append(label)
        self.labels.sort()
        self.colormap[label] = color
        self._add_to_level_group(label)
        return label_node

    def add_label(self, label: int, name: str, color: Tuple[int, int, int] = (255, 255, 255)):
        """Adds `label` to the hierarchy under the parent derived from its bits (see `get_parent`).

        A top-level label is attached to the root node when the root exists.
        """
        # TODO check that `label` is not present already
        parent_label = self.get_parent(label)
        if parent_label == self.ROOT:
            parent = self.nodes.get(parent_label)
        else:
            parent = self.nodes[parent_label]
        self._insert_label(label, name, color, parent_label, parent)

    def add_child_label(self, parent: int, name: str, color: Tuple[int, int, int]) -> 'Node':
        """Adds a new child label (see `get_available_label`) to the label `parent` and returns its node."""
        label = self.get_available_label(parent)
        return self._insert_label(label, name, color, parent, self.nodes[parent])