"""Batch objects are used for grouping other Shape and Batch objects.
"""
from typing import Any, Iterator, List, Sequence
from numpy import around, array
from typing_extensions import Self, Dict
import networkx as nx
from .all_enums import Types, batch_types, get_enum_value
from .common import common_properties, _set_Nones, Point, Line
from .core import Base
from .bbox import bounding_box
from ..canvas.style_map import batch_args
from ..helpers.validation import validate_args
from ..geometry.geometry import(
fix_degen_points,
get_polygons,
all_close_points,
mid_point,
distance,
connected_pairs,
round_segment,
round_point
)
from ..helpers.graph import is_cycle, is_open_walk, Graph
from ..settings.settings import defaults
from .merge import _merge_shapes, _merge_collinears
class Batch(Base):
    """
    A Batch object is a collection of other objects (Batch, Shape,
    and Tag objects). It can be used to apply a transformation to
    all the objects in the Batch. It is used for creating 1D and 2D
    patterns of objects. all_vertices, all_elements, etc. return a flat
    list of the specified objects gathered recursively from all the
    elements in the Batch.
    """

    def __init__(
        self,
        elements: Sequence[Any] = None,
        modifiers: Sequence["Modifier"] = None,
        subtype: Types = Types.BATCH,
        **kwargs,
    ):
        """
        Initialize a Batch object.

        Args:
            elements (Sequence[Any], optional): The elements to include in the
                batch. A list is stored as-is, a tuple is converted to a list,
                and any other single object is wrapped in a one-element list.
            modifiers (Sequence[Modifier], optional): The modifiers to apply
                to the batch.
            subtype (Types, optional): The subtype of the batch. Must be one
                of ``batch_types``.
            kwargs (dict): Additional keyword arguments. They are validated
                against ``batch_args`` and then set as attributes.

        Raises:
            ValueError: If ``subtype`` is not a valid batch subtype.
        """
        validate_args(kwargs, batch_args)
        if elements is None:
            self.elements = []
        elif isinstance(elements, list):
            self.elements = elements
        elif isinstance(elements, tuple):
            # Tuples are immutable; store a list so append/insert/extend work.
            self.elements = list(elements)
        else:
            # A single element (Shape, Batch, Tag, ...) was passed.
            self.elements = [elements]
        self.type = Types.BATCH
        if subtype not in batch_types:
            raise ValueError(f"Invalid subtype '{subtype}' for a Batch object!")
        self.subtype = get_enum_value(Types, subtype)
        self.modifiers = modifiers
        self.blend_mode = None
        self.alpha = None
        self.line_alpha = None
        self.fill_alpha = None
        self.text_alpha = None
        self.clip = False  # if clip is True, the batch.mask is used as a clip path
        self.mask = None
        self.even_odd_rule = False
        self.blend_group = False
        self.transparency_group = False
        common_properties(self)
        for key, value in kwargs.items():
            setattr(self, key, value)

    def set_attribs(self, attrib, value):
        """
        Set ``attrib`` to ``value`` on each direct element of the batch.

        Nested Batch elements always receive the attribute. Other elements
        receive it only if they already have an attribute with that name.

        Args:
            attrib (str): The attribute to set.
            value (Any): The value to set the attribute to.
        """
        for element in self.elements:
            if element.type == Types.BATCH or hasattr(element, attrib):
                setattr(element, attrib, value)

    def set_batch_attr(self, attrib: str, value: Any) -> Self:
        """
        Set the attribute to the given value for the batch itself.

        ``batch.attrib = value`` would set the attribute on the elements
        of the batch object but not on the batch itself.

        Args:
            attrib (str): The attribute to set.
            value (Any): The value to set the attribute to.

        Returns:
            Self: The batch object, so calls can be chained.
        """
        # Write straight into the instance dict so the value lands on the
        # batch itself rather than on its elements.
        self.__dict__[attrib] = value
        return self

    def __str__(self):
        """
        Return a string representation of the batch.

        Returns:
            str: ``Batch()`` when empty, the full element list for one or two
            elements, otherwise the first and last elements.
        """
        if not self.elements:
            return "Batch()"
        if len(self.elements) in [1, 2]:
            return f"Batch({self.elements})"
        return f"Batch({self.elements[0]}...{self.elements[-1]})"

    def __repr__(self):
        """
        Return a string representation of the batch (same as ``__str__``).

        Returns:
            str: The string representation of the batch.
        """
        return self.__str__()

    def __len__(self):
        """
        Return the number of direct elements in the batch.

        Returns:
            int: The number of elements in the batch.
        """
        return len(self.elements)

    def __getitem__(self, subscript):
        """
        Get the element(s) at the given subscript.

        Args:
            subscript (int or slice): The subscript to get the element(s) from.

        Returns:
            Any: The element at an index, or a list of elements for a slice.
        """
        # List indexing handles both int and slice subscripts directly.
        return self.elements[subscript]

    def __setitem__(self, subscript, value):
        """
        Set the element(s) at the given subscript.

        Args:
            subscript (int or slice): The subscript to set the element(s) at.
            value (Any): The value to set the element(s) to.

        Raises:
            TypeError: If ``subscript`` is neither a slice nor an int.
        """
        if not isinstance(subscript, (slice, int)):
            raise TypeError("Invalid subscript type")
        self.elements[subscript] = value

    def __add__(self, other: "Batch") -> "Batch":
        """
        Return a new batch with this batch's elements followed by other's.

        Elements of ``other`` are appended to a copy of this batch; ``append``
        skips elements already present.

        Args:
            other (Batch): The other batch to add.

        Returns:
            Batch: The combined batch.

        Raises:
            RuntimeError: If the other object is not a batch.
        """
        # getattr so that objects without a ``type`` attribute also raise the
        # documented RuntimeError instead of an AttributeError.
        if getattr(other, "type", None) != Types.BATCH:
            raise RuntimeError(
                "Invalid object. Only Batch objects can be added together!"
            )
        batch = self.copy()
        for element in other.elements:
            batch.append(element)
        return batch

    def __bool__(self):
        """
        Return whether the batch has any elements.

        Returns:
            bool: True if the batch has elements, False otherwise.
        """
        return len(self.elements) > 0

    def __iter__(self):
        """
        Return an iterator over the direct elements in the batch.

        Returns:
            Iterator: An iterator over the elements in the batch.
        """
        return iter(self.elements)

    def _duplicates(self, elements):
        """
        Check ``elements`` against the batch and against each other.

        Args:
            elements (Sequence[Any]): The elements to check for duplicates.

        Raises:
            ValueError: If any element's ``id`` matches the id of an element
                already in the batch.

        Returns:
            bool: True if ``elements`` itself contains duplicates, False
            otherwise.
        """
        # Build the id lookup once instead of once per checked element.
        existing_ids = {x.id for x in self.elements}
        for element in elements:
            if element.id in existing_ids:
                raise ValueError("Only unique elements are allowed!")
        return len(set(elements)) != len(elements)

    def proximity(self, dist_tol: float = None, n: int = 5) -> list:
        """
        Return up to ``n`` pairs of distinct, nearby vertices in the batch.

        Args:
            dist_tol (float, optional): The distance tolerance for proximity.
                Defaults to ``defaults["dist_tol"]``.
            n (int, optional): The maximum number of pairs to return.

        Returns:
            list: The first ``n`` pairs reported by ``all_close_points`` whose
            distance (third item) is greater than zero. Presumably each pair
            is ``(index1, index2, distance)`` with indices into
            ``all_vertices`` — TODO confirm against all_close_points.
        """
        if dist_tol is None:
            dist_tol = defaults["dist_tol"]
        # Tag every vertex with its index so pairs can be traced back.
        indexed = [(*vert, i) for i, vert in enumerate(self.all_vertices)]
        _, pairs = all_close_points(indexed, dist_tol=dist_tol, with_dist=True)
        return [pair for pair in pairs if pair[2] > 0][:n]

    def append(self, element: Any) -> Self:
        """
        Append the element to the batch unless it is already present.

        Args:
            element (Any): The element to append.

        Returns:
            Self: The batch object.
        """
        if element not in self.elements:
            self.elements.append(element)
        return self

    def reverse(self) -> Self:
        """
        Reverse the order of the elements in the batch.

        Returns:
            Self: The batch object.
        """
        self.elements = self.elements[::-1]
        return self

    def insert(self, index, element: Any) -> Self:
        """
        Insert the element at the given index unless it is already present.

        Args:
            index (int): The index to insert the element at.
            element (Any): The element to insert.

        Returns:
            Self: The batch object.
        """
        if element not in self.elements:
            self.elements.insert(index, element)
        return self

    def remove(self, element: Any) -> Self:
        """
        Remove the element from the batch if it is present.

        Args:
            element (Any): The element to remove.

        Returns:
            Self: The batch object.
        """
        if element in self.elements:
            self.elements.remove(element)
        return self

    def pop(self, index: int) -> Any:
        """
        Remove the element at the given index and return it.

        Args:
            index (int): The index to remove the element from.

        Returns:
            Any: The removed element.
        """
        return self.elements.pop(index)

    def clear(self) -> Self:
        """
        Remove all elements from the batch.

        Returns:
            Self: The batch object.
        """
        self.elements = []
        return self

    def extend(self, elements: Sequence[Any]) -> Self:
        """
        Extend the batch with the given elements, skipping ones already present.

        Args:
            elements (Sequence[Any]): The elements to extend the batch with.

        Returns:
            Self: The batch object.
        """
        for element in elements:
            if element not in self.elements:
                self.elements.append(element)
        return self

    def iter_elements(self, element_type: Types = None) -> Iterator:
        """Iterate over all non-batch elements, including those in nested batches.

        Args:
            element_type (Types, optional): Only yield elements of this type.
                Defaults to None (yield every non-batch element).

        Returns:
            Iterator: An iterator over the elements in the batch.
        """
        for elem in self.elements:
            if elem.type == Types.BATCH:
                yield from elem.iter_elements(element_type)
            elif element_type is None or elem.type == element_type:
                yield elem

    @property
    def all_elements(self) -> list[Any]:
        """Return a flat list of all non-batch elements in the batch,
        including the elements in the nested batches.

        Returns:
            list[Any]: A list of all elements in the batch.
        """
        elements = []
        for elem in self.elements:
            if elem.type == Types.BATCH:
                elements.extend(elem.all_elements)
            else:
                elements.append(elem)
        return elements

    @property
    def all_shapes(self) -> list["Shape"]:
        """Return a list of all shapes in the batch (recursively).

        Returns:
            list[Shape]: A list of all shapes in the batch.
        """
        return [elem for elem in self.all_elements if elem.type == Types.SHAPE]

    @property
    def all_vertices(self) -> list[Point]:
        """Return a list of all points in the batch in their
        transformed positions.

        Returns:
            list[Point]: The vertices of every shape, in ``all_shapes`` order.
        """
        # all_elements is already flattened (it never contains batches), so
        # only shapes contribute vertices.
        vertices = []
        for shape in self.all_shapes:
            vertices.extend(shape.vertices)
        return vertices

    @property
    def all_segments(self) -> list[Line]:
        """Return a list of all segments in the batch.

        Returns:
            list[Line]: The ``vertex_pairs`` of every shape, concatenated.
        """
        segments = []
        for shape in self.all_shapes:
            segments.extend(shape.vertex_pairs)
        return segments

    def _get_graph_nodes_and_edges(self, dist_tol: float = None, n_round=None):
        """Get the graph nodes and edges for the batch.

        Every shape vertex is rounded to ``n_round`` decimals. Both vertices of
        each pair reported as close by ``all_close_points`` are replaced with
        their mid point. Each distinct resulting coordinate becomes a node.
        Consecutive vertices of each shape become edges, plus a closing edge
        for closed shapes.

        Args:
            dist_tol (float, optional): The distance tolerance for proximity.
                Defaults to ``defaults["dist_tol"]``.
            n_round (int, optional): The number of decimal places to round to.
                Defaults to ``defaults["n_round"]``.

        Returns:
            tuple: ``(d_node_id__rounded_coords, edges)``. The first item maps
            node id -> coordinate tuple; ``edges`` is a list of node-id pairs.
        """
        if n_round is None:
            n_round = defaults["n_round"]
        _set_Nones(self, ["dist_tol", "n_round"], [dist_tol, n_round])
        # Resolve the tolerance explicitly, as the other methods of this class
        # do, instead of passing None on to all_close_points.
        if dist_tol is None:
            dist_tol = defaults["dist_tol"]
        vertices = self.all_vertices
        shapes = self.all_shapes
        rounded_vertices = [tuple(around(vert, n_round)) for vert in vertices]
        point_id = [[vert[0], vert[1], i] for i, vert in enumerate(vertices)]
        _, pairs = all_close_points(point_id, dist_tol=dist_tol, with_dist=True)
        for id1, id2, _ in pairs:
            # NOTE(review): a vertex that occurs in several pairs keeps only the
            # mid point of the last pair processed, so three or more nearly
            # coincident vertices may not collapse into a single node.
            average = tuple(mid_point(vertices[id1], vertices[id2]))
            rounded_vertices[id1] = average
            rounded_vertices[id2] = average
        d_coords_node_id = {}
        d_node_id__rounded_coords = {}
        for node_id, vertex in enumerate(set(rounded_vertices)):
            d_coords_node_id[vertex] = node_id
            d_node_id__rounded_coords[node_id] = vertex
        edges = []
        ind = 0
        for shape in shapes:
            # all_vertices and all_shapes walk the shapes in the same order, so
            # this shape's vertices occupy the next len(shape.vertices) slots.
            n_vertices = len(shape.vertices)
            node_ids = [
                d_coords_node_id[coords]
                for coords in rounded_vertices[ind : ind + n_vertices]
            ]
            ind += n_vertices
            edges.extend(connected_pairs(node_ids))
            # Guard against closed shapes without vertices (empty node_ids).
            if shape.closed and node_ids:
                edges.append((node_ids[-1], node_ids[0]))
        return d_node_id__rounded_coords, edges

    def as_graph(
        self,
        directed: bool = False,
        weighted: bool = False,
        dist_tol: float = None,
        atol=None,
        n_round: int = None,
    ) -> Graph:
        """Return the batch as a Graph object.

        Graph.nx is the networkx graph. Each node carries its coordinates in
        the ``pos`` node attribute.

        Args:
            directed (bool, optional): Whether the graph is directed. Defaults to False.
            weighted (bool, optional): Whether the edges get a ``weight`` equal to
                the distance between their end points. Defaults to False.
            dist_tol (float, optional): The distance tolerance for proximity. Defaults to None.
            atol (optional): The absolute tolerance. Only forwarded to
                ``_set_Nones`` here. Defaults to None.
            n_round (int, optional): The number of decimal places to round to. Defaults to None.

        Returns:
            Graph: The batch as a Graph object.
        """
        _set_Nones(self, ["dist_tol", "atol", "n_round"], [dist_tol, atol, n_round])
        d_node_id_coords, edges = self._get_graph_nodes_and_edges(dist_tol, n_round)
        if directed:
            nx_graph = nx.DiGraph()
            graph_type = Types.DIRECTED
        else:
            nx_graph = nx.Graph()
            graph_type = Types.UNDIRECTED
        for id_, coords in d_node_id_coords.items():
            nx_graph.add_node(id_, pos=coords)
        if weighted:
            for node1, node2 in edges:
                p1 = d_node_id_coords[node1]
                p2 = d_node_id_coords[node2]
                nx_graph.add_edge(node1, node2, weight=distance(p1, p2))
            subtype = Types.WEIGHTED
        else:
            nx_graph.update(edges)
            subtype = Types.NONE
        return Graph(type=graph_type, subtype=subtype, nx_graph=nx_graph)

    def graph_summary(self, dist_tol: float = None, n_round: int = None) -> str:
        """Return a textual summary of the Batch object viewed as a graph.

        Lists the number of shapes and vertices, the vertex count per shape,
        and, for each connected component ("island"), whether it is a cycle,
        an open walk, or degenerate (with the nodes of degree > 2).

        Args:
            dist_tol (float, optional): The distance tolerance for proximity. Defaults to None.
            n_round (int, optional): The number of decimal places to round to. Defaults to None.

        Returns:
            str: A multi-line representation of the Batch object as a graph.
        """
        if dist_tol is None:
            dist_tol = defaults["dist_tol"]
        if n_round is None:
            n_round = defaults["n_round"]
        all_shapes = self.all_shapes
        all_vertices = self.all_vertices
        lines = [
            "Batch summary:",
            f"# shapes: {len(all_shapes)}",
            f"# vertices: {len(all_vertices)}",
        ]
        for shape in all_shapes:
            if shape.subtype:
                s = (
                    f"# vertices in shape(id: {shape.id}, subtype: "
                    f"{shape.subtype}): {len(shape.vertices)}"
                )
            else:
                s = f"# vertices in shape(id: {shape.id}): " f"{len(shape.vertices)}"
            lines.append(s)
        graph = self.as_graph(dist_tol=dist_tol, n_round=n_round).nx_graph
        for island in nx.connected_components(graph):
            lines.append(f"Island: {island}")
            if is_cycle(graph, island):
                lines.append(f"Cycle: {len(island)} nodes")
            elif is_open_walk(graph, island):
                lines.append(f"Open Walk: {len(island)} nodes")
            else:
                degens = [node for node in island if graph.degree(node) > 2]
                degrees = f"{[(node, graph.degree(node)) for node in degens]}"
                lines.append(f"Degenerate: {len(island)} nodes")
                lines.append(f"(Node, Degree): {degrees}")
            lines.append("-" * 40)
        return "\n".join(lines)

    def _merge_collinears(self, edges, n_round=2):
        """Merge collinear edges in the batch.

        Delegates to the module-level ``_merge_collinears`` helper imported
        from ``.merge``.

        Args:
            edges (list): The edges to merge.
            n_round (int, optional): The number of decimal places to round to.
                Defaults to 2.

        Returns:
            list: The merged edges.
        """
        # Inside a method body this name resolves to the imported module-level
        # function, not to this method.
        return _merge_collinears(self, edges, n_round=n_round)

    def merge_shapes(
        self, dist_tol: float = None, n_round: int = None) -> Self:
        """Merge the shapes in the batch if they are connected.

        Returns a new batch containing the merged shapes as well as the
        shapes that could not be merged. Delegates to ``_merge_shapes``.

        Args:
            dist_tol (float, optional): The distance tolerance for merging
                shapes. Defaults to None.
            n_round (int, optional): The number of decimal places to round to.
                Defaults to None.

        Returns:
            Self: The batch object with merged shapes.
        """
        return _merge_shapes(self, dist_tol=dist_tol, n_round=n_round)

    def _get_edges_and_segments(self, dist_tol: float = None, n_round: int = None):
        """Get the edges and the rounded segments for the batch.

        Reads ``self.d_coord_node``; in this file that attribute is only set by
        ``_set_node_dictionaries``, which must therefore run first.

        Args:
            dist_tol (float, optional): The distance tolerance for proximity.
                Currently unused. Defaults to None.
            n_round (int, optional): The number of decimal places to round to. Defaults to None.

        Returns:
            tuple: ``(edges, segments)``. ``segments`` are the rounded segments;
            ``edges`` are the matching node-id pairs.

        Raises:
            KeyError: If a rounded segment end point has no node id.
        """
        if dist_tol is None:
            dist_tol = defaults["dist_tol"]
        if n_round is None:
            n_round = defaults["n_round"]
        d_coord_node = self.d_coord_node
        segments = [round_segment(segment, n_round) for segment in self.all_segments]
        edges = []
        for p1, p2 in segments:
            edges.append((d_coord_node[p1], d_coord_node[p2]))
        return edges, segments

    def _set_node_dictionaries(self, coords: List[Point], n_round: int = 2) -> None:
        """Build the node-id <-> coordinate dictionaries for the batch.

        Coordinates are rounded to ``n_round`` decimals and de-duplicated. They
        are then ordered by y, with ties broken by the whole coordinate tuple
        (i.e. by x for 2D points). Node ids are assigned in that order.

        The results are stored on the batch:
            d_node_coord: Dictionary of node id to coordinates.
            d_coord_node: Dictionary of coordinates to node id.

        Args:
            coords (List[Point]): List of vertices.
            n_round (int, optional): Number of rounding digits. Defaults to 2.
        """
        unique_coords = {tuple(round_point(coord, n_round)) for coord in coords}
        # One keyed sort, equivalent to sorting by the full tuple and then
        # stable-sorting by y.
        ordered = sorted(unique_coords, key=lambda c: (c[1], c))
        self.d_node_coord = dict(enumerate(ordered))
        self.d_coord_node = {coord: i for i, coord in enumerate(ordered)}

    def all_polygons(self, dist_tol: float = None) -> list:
        """Return a list of all polygons in the batch in their
        transformed positions.

        Closed shapes with more than two primary points are returned directly
        as vertex lists. All other shapes are cleaned with
        ``fix_degen_points`` and passed to ``get_polygons``.

        Args:
            dist_tol (float, optional): The distance tolerance for proximity. Defaults to None.

        Returns:
            list: The detected polygons followed by the closed shapes'
            vertex lists.
        """
        if dist_tol is None:
            dist_tol = defaults["dist_tol"]
        exclude = []  # closed shapes used as polygons as they are
        include = []  # shapes that go through polygon detection
        for shape in self.all_shapes:
            if len(shape.primary_points) > 2 and shape.closed:
                exclude.append(shape.vertices)
            else:
                include.append(shape)
        polylines = [
            fix_degen_points(element.vertices, dist_tol=dist_tol, closed=element.closed)
            for element in include
        ]
        if not polylines:
            return exclude
        # NOTE(review): every polyline is cleaned a second time with
        # closed=True, even for open shapes -- confirm this is intended.
        fixed_polylines = [
            fix_degen_points(polyline, dist_tol=dist_tol, closed=True)
            for polyline in polylines
        ]
        polygons = get_polygons(fixed_polylines, dist_tol=dist_tol)
        return polygons + exclude

    def copy(self) -> "Batch":
        """Return a copy of the batch.

        Elements are copied with their own ``copy`` method. Modifiers are
        shared, and the subtype and custom attributes are carried over.

        Returns:
            Batch: A copy of the batch.
        """
        b = Batch(modifiers=self.modifiers)
        # Without this the constructor default would reset the copy's subtype
        # to Types.BATCH.
        b.subtype = self.subtype
        if self.elements:
            b.elements = [elem.copy() for elem in self.elements]
        else:
            b.elements = []
        for attrib in custom_batch_attributes(self):
            setattr(b, attrib, getattr(self, attrib))
        return b

    @property
    def b_box(self):
        """Return the bounding box of the batch.

        Computed from the bounding-box corners of each direct element.

        Returns:
            BoundingBox: The bounding box of the batch.
        """
        xy_list = []
        for elem in self.elements:
            # To do: we should eliminate this. Just add all points.
            xy_list.extend(elem.b_box.corners)
        # To do: memoize the bounding box
        return bounding_box(array(xy_list))

    def _modify(self, modifier):
        """Apply a modifier to the batch.

        Args:
            modifier (Modifier): The modifier to apply.
        """
        # NOTE(review): ``_update`` calls ``modifier.apply(element)`` with an
        # argument while this calls it without one -- confirm the Modifier API.
        modifier.apply()

    def _update(self, xform_matrix, reps: int = 0):
        """Update the batch with the given transformation matrix.

        If reps is 0, the transformation is applied in place to all elements.
        If reps is greater than 0, each repetition copies the elements created
        by the previous repetition (initially the original elements), applies
        the transformation to the copies, and appends them to the batch. In
        both cases the batch's modifiers are applied to every touched element.

        Args:
            xform_matrix (ndarray): The transformation matrix.
            reps (int, optional): The number of repetitions. Defaults to 0.

        Returns:
            Self: The batch object.
        """
        if reps == 0:
            for element in self.elements:
                element._update(xform_matrix, reps=0)
                if self.modifiers:
                    for modifier in self.modifiers:
                        modifier.apply(element)
        else:
            elements = self.elements[:]
            new = []
            for _ in range(reps):
                for element in elements:
                    new_element = element.copy()
                    new_element._update(xform_matrix)
                    self.elements.append(new_element)
                    new.append(new_element)
                    if self.modifiers:
                        for modifier in self.modifiers:
                            modifier.apply(new_element)
                # The next repetition transforms the copies just created.
                elements, new = new, []
        return self
def custom_batch_attributes(item: Batch) -> List[str]:
    """
    Return the names of the custom attributes of a Batch instance.

    An attribute counts as custom when it appears in ``dir(item)`` but not
    on a freshly constructed reference Batch. Examples are attributes set
    after construction or passed to ``Batch`` as extra keyword arguments.

    Args:
        item (Batch): The batch object.

    Returns:
        List[str]: The custom attribute names, sorted so the order is
        deterministic across runs.

    Raises:
        TypeError: If ``item`` is not a Batch instance.
    """
    # Local import, presumably to avoid a circular import with .shape.
    from .shape import Shape

    if not isinstance(item, Batch):
        raise TypeError("Invalid item type")
    # Reference batch whose attribute names are treated as "native".
    dummy = Batch([Shape([(0, 0), (1, 0)])])
    native_attribs = set(dir(dummy))
    return sorted(set(dir(item)) - native_attribs)