Source code for whitecanvas.backend.matplotlib.markers

from __future__ import annotations
import weakref
from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

import matplotlib.pyplot as plt
from matplotlib.collections import PathCollection
import matplotlib.markers as mmarkers
import matplotlib.transforms as mtransforms

from psygnal import throttled
from whitecanvas.backend.matplotlib._base import MplLayer
from whitecanvas.protocols import MarkersProtocol, check_protocol
from whitecanvas.types import Symbol, FacePattern, LineStyle
from whitecanvas.utils.normalize import as_color_array

if TYPE_CHECKING:
    from whitecanvas.backend.matplotlib.canvas import Canvas
    from matplotlib.backend_bases import MouseEvent as mplMouseEvent


def _get_path(symbol: Symbol):
    """Return the marker path for *symbol* with the marker's transform applied.

    A ``MarkerStyle`` is built from ``symbol.value``; its path is then
    transformed by the style's own transform so callers receive the final
    outline.
    """
    style = mmarkers.MarkerStyle(symbol.value)
    base_path = style.get_path()
    return base_path.transformed(style.get_transform())


@check_protocol(MarkersProtocol)
class Markers(PathCollection, MplLayer):
    """Marker layer for the matplotlib backend, drawn as a ``PathCollection``.

    Every marker shares one symbol path. Sizes, face/edge colors, edge
    styles and edge widths are handled per marker through the ``_plt_*``
    accessors below.
    """

    def __init__(self, xdata, ydata):
        """Create one marker per ``(xdata[i], ydata[i])`` pair.

        The collection is positioned with the current axes' data transform
        (``plt.gca().transData``) while the marker paths themselves use an
        identity transform.
        """
        offsets = np.stack([xdata, ydata], axis=1)
        self._symbol = Symbol.CIRCLE
        super().__init__(
            (_get_path(self._symbol),),
            sizes=[6] * len(offsets),
            offsets=offsets,
            offset_transform=plt.gca().transData,
        )
        self.set_transform(mtransforms.IdentityTransform())
        # Edge styles are tracked here as ``LineStyle`` members so that the
        # getter can return them without parsing matplotlib's dash patterns.
        self._edge_styles = [LineStyle.SOLID] * len(offsets)
        # Pick callbacks are queued until ``post_add`` connects them.
        self._pick_callbacks = []
        self._hover_texts: list[str] | None = None
        # Placeholder with the same call interface as ``weakref.ref``; it is
        # replaced by a real weak reference in ``post_add``.
        self._canvas = lambda: None

    ##### XYDataProtocol #####
    def _plt_get_data(self):
        """Return the marker positions as an ``(x, y)`` tuple of columns."""
        offsets = self.get_offsets()
        return offsets[:, 0], offsets[:, 1]

    def _plt_set_data(self, xdata, ydata):
        """Replace the marker positions with the given x/y coordinates."""
        data = np.stack([xdata, ydata], axis=1)
        self.set_offsets(data)

    ##### HasSymbol protocol #####
    def _plt_get_symbol(self) -> Symbol:
        """Return the symbol currently used for the markers."""
        return self._symbol

    def _plt_set_symbol(self, symbol: Symbol):
        """Use *symbol* for every marker in the collection."""
        path = _get_path(symbol)
        self.set_paths([path] * len(self.get_offsets()))
        self._symbol = symbol

    def _plt_get_symbol_size(self) -> NDArray[np.floating]:
        """Return marker sizes (square root of the collection's sizes)."""
        return np.sqrt(self.get_sizes())

    def _plt_set_symbol_size(self, size: float | NDArray[np.floating]):
        """Set marker sizes; a scalar is broadcast to all markers.

        The collection stores the squared value, mirroring the square root
        taken in ``_plt_get_symbol_size``.
        """
        if isinstance(size, (int, float, np.number)):
            size = np.full(len(self.get_offsets()), size)
        self.set_sizes(size**2)

    ##### HasFaces protocol #####
    def _plt_get_face_color(self) -> NDArray[np.float32]:
        """Return the face colors of the collection."""
        return self.get_facecolor()

    def _plt_set_face_color(self, color: NDArray[np.float32]):
        """Set face colors, normalized to one color per marker."""
        color = as_color_array(color, len(self.get_offsets()))
        self.set_facecolor(color)

    def _plt_get_face_pattern(self) -> list[FacePattern]:
        """Return the face pattern repeated once per marker.

        The collection has a single hatch, so every entry is the same.
        """
        return [FacePattern(self.get_hatch() or "")] * len(self.get_offsets())

    def _plt_set_face_pattern(self, pattern: FacePattern | list[FacePattern]):
        """Set the hatch pattern.

        When a sequence is given only its first element is used, because the
        hatch is set once for the whole collection.
        """
        if not isinstance(pattern, FacePattern):
            pattern = pattern[0]
        hatch = None if pattern is FacePattern.SOLID else pattern.value
        self.set_hatch(hatch)

    ##### HasEdges protocol #####
    def _plt_get_edge_color(self) -> NDArray[np.float32]:
        """Return the edge colors of the collection."""
        return self.get_edgecolor()

    def _plt_set_edge_color(self, color: NDArray[np.float32]):
        """Set edge colors, normalized to one color per marker."""
        color = as_color_array(color, len(self.get_offsets()))
        self.set_edgecolor(color)

    def _plt_get_edge_style(self) -> list[LineStyle]:
        """Return the edge style of each marker as ``LineStyle`` members."""
        return self._edge_styles

    def _plt_set_edge_style(self, style: LineStyle | list[LineStyle]):
        """Set the edge style; a single ``LineStyle`` applies to all markers."""
        if isinstance(style, LineStyle):
            styles = [style] * len(self.get_offsets())
        else:
            styles = list(style)
        # matplotlib receives the raw values, but the enum members are what
        # gets stored so the getter keeps returning ``LineStyle`` objects
        # (matching what ``__init__`` stores).
        self.set_linestyle([s.value for s in styles])
        self._edge_styles = styles

    def _plt_get_edge_width(self) -> NDArray[np.floating]:
        """Return the edge line widths of the collection."""
        return self.get_linewidth()

    def _plt_set_edge_width(self, width: float | NDArray[np.floating]):
        """Set edge widths; a scalar is broadcast to all markers."""
        if isinstance(width, (int, float, np.number)):
            width = np.full(len(self.get_offsets()), width)
        self.set_linewidth(width)

    def _plt_connect_pick_event(self, callback):
        """Queue *callback* to be called with the picked marker indices.

        The callback is wrapped so that pick events of other artists are
        ignored. It is only stored here; ``post_add`` does the actual
        connection to the figure canvas.
        """
        # NOTE(review): callbacks registered after ``post_add`` has run are
        # not connected by this class -- confirm callers register first.

        def cb(event):
            if event.artist is not self:
                return
            callback(event.ind)

        self._pick_callbacks.append(cb)

    def _plt_set_hover_text(self, texts: list[str]):
        """Set the per-marker texts used for the hover tooltip."""
        self._hover_texts = texts

    def post_add(self, canvas: Canvas):
        """Hook up figure events once the layer belongs to *canvas*.

        Connects every queued pick callback, keeps a weak reference to the
        canvas, and connects a throttled hover handler to mouse motion.
        """
        fig = self.get_figure()
        for cb in self._pick_callbacks:
            fig.canvas.mpl_connect("pick_event", cb)
        self._canvas = weakref.ref(canvas)
        fig.canvas.mpl_connect(
            "motion_notify_event", throttled(self._on_hover, timeout=500)
        )

    def _on_hover(self, event: mplMouseEvent = None):
        """Show the hover text of the marker under the cursor, if any.

        Does nothing when no hover texts are set, the layer is invisible,
        the canvas has been garbage-collected, or the event is outside the
        canvas' axes. Hides the tooltip when the cursor is over no marker.
        """
        if self._hover_texts is None or not self._plt_get_visible():
            return
        canvas = self._canvas()
        if canvas is None:
            return
        # ``event`` defaults to None in the signature, so guard before use.
        if event is None or event.inaxes is not canvas._axes:
            return
        contains, ind = self.contains(event)
        if not contains or len(ind["ind"]) == 0:
            canvas._hide_tooltip()
            return
        # Several markers may overlap under the cursor; the first hit wins.
        index: int = ind["ind"][0]
        hover_text = self._hover_texts[index]
        xy = self.get_offsets()[index]
        canvas._set_tooltip(xy, hover_text)