# Source code for whitecanvas.backend.matplotlib.bars

from __future__ import annotations

import numpy as np
from numpy.typing import NDArray

from matplotlib.container import BarContainer
from matplotlib.patches import Rectangle
from whitecanvas.backend.matplotlib._base import MplLayer
from whitecanvas.protocols import BarProtocol, check_protocol
from whitecanvas.types import FacePattern, LineStyle
from whitecanvas.utils.normalize import as_color_array


@check_protocol(BarProtocol)
class Bars(BarContainer, MplLayer):
    """Matplotlib backend for a bar layer.

    Each bar is a ``matplotlib.patches.Rectangle`` held in ``self.patches``
    (inherited from ``BarContainer``). The ``_plt_*`` methods implement
    ``BarProtocol`` by reading from / writing to those rectangles.
    """

    def __init__(self, xlow, xhigh, ylow, yhigh):
        """Create one rectangle per bar.

        Parameters
        ----------
        xlow, xhigh : array-like
            Left and right edges of each bar.
        ylow, yhigh : array-like
            Bottom and top edges of each bar.

        The number of bars is the length of the shortest input (``zip``).
        """
        # asarray lets plain sequences (e.g. lists) be subtracted element-wise.
        xlow = np.asarray(xlow)
        xhigh = np.asarray(xhigh)
        ylow = np.asarray(ylow)
        yhigh = np.asarray(yhigh)
        patches = []
        for x, y, dx, dy in zip(xlow, ylow, xhigh - xlow, yhigh - ylow):
            rect = Rectangle(xy=(x, y), width=dx, height=dy, linestyle="-")
            # NOTE(review): sets a private attribute on the path returned by
            # get_path(); presumably mirrors matplotlib's own Axes.bar so bars
            # are drawn smoothly on curved (e.g. polar) axes — confirm.
            rect.get_path()._interpolation_steps = 100
            patches.append(rect)
        super().__init__(patches)
        # Visibility is tracked on the container and pushed to every patch
        # in _plt_set_visible.
        self._visible = True

    def _plt_get_visible(self):
        """Return the last visibility set on this layer."""
        return self._visible

    def _plt_set_visible(self, visible):
        """Show or hide every bar."""
        for patch in self.patches:
            patch.set_visible(visible)
        self._visible = visible

    def _plt_set_zorder(self, zorder: int):
        """Apply the same z-order to every bar."""
        for patch in self.patches:
            patch.set_zorder(zorder)

    ##### XXYYDataProtocol #####
    def _plt_get_data(self):
        """Return ``(x0, x1, y0, y1)`` arrays of bar edges.

        float64 is used so values round-trip through ``_plt_set_data``
        without the precision loss float32 would introduce for large
        coordinates.
        """
        n = len(self.patches)
        x0 = np.empty(n, dtype=np.float64)
        x1 = np.empty(n, dtype=np.float64)
        y0 = np.empty(n, dtype=np.float64)
        y1 = np.empty(n, dtype=np.float64)
        for i, patch in enumerate(self.patches):
            x0[i] = patch.get_x()
            x1[i] = patch.get_x() + patch.get_width()
            y0[i] = patch.get_y()
            y1[i] = patch.get_y() + patch.get_height()
        return x0, x1, y0, y1

    def _plt_set_data(self, x0, x1, y0, y1):
        """Move/resize existing bars to the given edges.

        NOTE(review): the number of patches is fixed at construction; because
        of ``zip``, extra input entries are ignored and missing ones leave the
        remaining patches unchanged.
        """
        for patch, x0i, x1i, y0i, y1i in zip(self.patches, x0, x1, y0, y1):
            patch.set_x(x0i)
            patch.set_width(x1i - x0i)
            patch.set_y(y0i)
            patch.set_height(y1i - y0i)

    ##### HasFace protocol #####
    def _plt_get_face_color(self) -> NDArray[np.float32]:
        """Return an ``(n_bars, 4)`` RGBA array of face colors (``(0, 4)`` if empty)."""
        # np.stack fails on an empty list; reshape keeps the (0, 4) shape instead.
        colors = [patch.get_facecolor() for patch in self.patches]
        return np.asarray(colors, dtype=np.float64).reshape(-1, 4)

    def _plt_set_face_color(self, color: NDArray[np.float32]):
        """Set face colors from a single color or one color per bar."""
        color = as_color_array(color, size=len(self.patches))
        for patch, c in zip(self.patches, color):
            patch.set_facecolor(c)

    def _plt_get_face_pattern(self) -> list[FacePattern]:
        """Return the hatch pattern of each bar; ``None`` hatch maps to ``""``."""
        return [FacePattern(patch.get_hatch() or "") for patch in self.patches]

    def _plt_set_face_pattern(self, pattern: FacePattern | list[FacePattern]):
        """Set the hatch of every bar; ``FacePattern.SOLID`` clears the hatch."""
        if isinstance(pattern, FacePattern):
            pattern = [pattern] * len(self.patches)
        for pat, patch in zip(pattern, self.patches):
            patch.set_hatch(None if pat is FacePattern.SOLID else pat.value)

    ##### HasEdges protocol #####
    def _plt_get_edge_color(self) -> NDArray[np.float32]:
        """Return an ``(n_bars, 4)`` RGBA array of edge colors (``(0, 4)`` if empty)."""
        colors = [patch.get_edgecolor() for patch in self.patches]
        return np.asarray(colors, dtype=np.float64).reshape(-1, 4)

    def _plt_set_edge_color(self, color: NDArray[np.float32]):
        """Set edge colors from a single color or one color per bar."""
        color = as_color_array(color, size=len(self.patches))
        for patch, c in zip(self.patches, color):
            patch.set_edgecolor(c)

    def _plt_get_edge_style(self) -> list[LineStyle]:
        """Return the edge line style of each bar."""
        return [LineStyle(patch.get_linestyle()) for patch in self.patches]

    def _plt_set_edge_style(self, style: LineStyle | list[LineStyle]):
        """Set edge line styles from a single style or one style per bar."""
        if isinstance(style, LineStyle):
            style = [style] * len(self.patches)
        for patch, s in zip(self.patches, style):
            patch.set_linestyle(s.value)

    def _plt_get_edge_width(self) -> NDArray[np.floating]:
        """Return the edge line width of each bar."""
        return np.array(
            [patch.get_linewidth() for patch in self.patches], dtype=np.float64
        )

    def _plt_set_edge_width(self, width: float | NDArray[np.floating]):
        """Set edge widths from a scalar (incl. 0-d array) or one width per bar."""
        # np.ndim also treats 0-d arrays as scalars; np.isscalar does not,
        # and zipping over a 0-d array raises TypeError.
        if np.ndim(width) == 0:
            width = [width] * len(self.patches)
        for patch, w in zip(self.patches, width):
            patch.set_linewidth(w)