from __future__ import annotations
from typing import Callable
from matplotlib.artist import Artist
import numpy as np
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.backend_bases import (
MouseEvent as mplMouseEvent,
MouseButton as mplMouseButton,
)
from matplotlib.lines import Line2D
from matplotlib.collections import Collection
from .bars import Bars
from .text import Texts as whitecanvasText
from .image import Image as whitecanvasImage
from ._labels import Title, XAxis, YAxis, XLabel, YLabel, XTicks, YTicks
from whitecanvas import protocols
from whitecanvas.types import MouseEvent, Modifier, MouseButton, MouseEventType
from whitecanvas.backend.matplotlib._base import MplLayer
@protocols.check_protocol(protocols.CanvasProtocol)
class Canvas:
    """Matplotlib backend for a single whitecanvas canvas.

    Wraps one ``matplotlib.axes.Axes`` and implements the ``_plt_*`` hooks of
    ``protocols.CanvasProtocol`` on top of it.
    """

    def __init__(self, ax: plt.Axes | None = None):
        """Wrap *ax*, or pyplot's current axes (``plt.gca()``) when it is None."""
        if ax is None:
            ax = plt.gca()
        self._axes = ax
        # Axis/label helper objects, each constructed with a reference to
        # this canvas.
        self._xaxis = XAxis(self)
        self._yaxis = YAxis(self)
        self._title = Title(self)
        self._xlabel = XLabel(self)
        self._ylabel = YLabel(self)
        self._xticks = XTicks(self)
        self._yticks = YTicks(self)
        ax.set_axisbelow(True)  # grid lines below other layers
        # A single annotation is created up front and reused for every
        # tooltip; it stays hidden until _set_tooltip() is called.
        self._annot = ax.annotate(
            text="", xy=(0, 0), xytext=(20, -20), textcoords="offset points",
            bbox=dict(fc="w"), fontproperties={"size": 14, "family": "Arial"},
        )  # fmt: skip
        self._annot.set_visible(False)

    def _set_tooltip(self, pos, text: str):
        """Show the tooltip with *text*, anchored at *pos* (data coordinates)."""
        self._annot.xy = pos
        self._annot.set_text(text)
        self._annot.set_visible(True)
        if fig := self._axes.get_figure():
            fig.canvas.draw_idle()

    def _hide_tooltip(self):
        """Hide the tooltip; a redraw is requested only if it was visible."""
        if self._annot.get_visible():
            self._annot.set_visible(False)
            if fig := self._axes.get_figure():
                fig.canvas.draw_idle()

    def _plt_get_native(self):
        """Return the wrapped matplotlib ``Axes``."""
        return self._axes

    def _plt_get_title(self):
        return self._title

    def _plt_get_xaxis(self):
        return self._xaxis

    def _plt_get_yaxis(self):
        return self._yaxis

    def _plt_get_xlabel(self):
        return self._xlabel

    def _plt_get_ylabel(self):
        return self._ylabel

    def _plt_get_xticks(self):
        return self._xticks

    def _plt_get_yticks(self):
        return self._yticks

    def _plt_reorder_layers(self, layers: list[MplLayer]):
        """Assign z-orders so that *layers* are drawn in list order."""
        for i, layer in enumerate(layers):
            layer._plt_set_zorder(i)

    def _plt_get_aspect_ratio(self) -> float | None:
        """Return the axes aspect ratio, or None when matplotlib reports "auto"."""
        out = self._axes.get_aspect()
        if out == "auto":
            return None
        return out

    def _plt_set_aspect_ratio(self, ratio: float | None):
        """Set a fixed aspect ratio, or restore "auto" when *ratio* is None."""
        if ratio is None:
            self._axes.set_aspect("auto")
        else:
            self._axes.set_aspect(ratio)

    def _plt_add_layer(self, layer: Artist):
        """Attach a backend layer to the axes.

        Each supported layer type is registered with the matching ``Axes``
        method. Afterwards the layer's optional ``post_add(canvas)`` hook is
        called.

        Raises
        ------
        NotImplementedError
            If *layer* is not one of the supported artist types.
        """
        if isinstance(layer, Line2D):
            self._axes.add_line(layer)
        elif isinstance(layer, Collection):
            # autolim=False: adding a collection does not update data limits.
            self._axes.add_collection(layer, autolim=False)
        elif isinstance(layer, Bars):
            for child in layer.patches:
                self._axes.add_patch(child)
            self._axes.add_container(layer)
        elif isinstance(layer, whitecanvasText):
            layer.set_transform(self._axes.transData)
            for child in layer._children:
                self._axes._add_text(child)
        elif isinstance(layer, whitecanvasImage):
            self._axes.add_artist(layer)
        else:
            raise NotImplementedError(f"{layer}")
        if hasattr(layer, "post_add"):
            layer.post_add(self)

    def _plt_remove_layer(self, layer: Artist):
        """Remove layer from the canvas"""
        layer.remove()

    def _plt_get_visible(self) -> bool:
        """Get visibility of canvas"""
        return self._axes.get_visible()

    def _plt_set_visible(self, visible: bool):
        """Set visibility of canvas"""
        self._axes.set_visible(visible)

    def _plt_connect_mouse_click(self, callback: Callable[[MouseEvent], None]):
        """Connect *callback* to single (non-double) button presses in this axes."""

        def _cb(ev: mplMouseEvent):
            if ev.inaxes is not self._axes or ev.dblclick:
                return
            callback(self._translate_mouse_event(ev, MouseEventType.CLICK))

        self._axes.figure.canvas.mpl_connect("button_press_event", _cb)

    def _plt_connect_mouse_drag(self, callback: Callable[[MouseEvent], None]):
        """Connect *callback* to mouse-motion events over this axes.

        The callback fires on every motion event inside the axes. No
        button-state filtering is done here.
        """

        def _cb(ev: mplMouseEvent):
            if ev.inaxes is not self._axes or ev.dblclick:
                return
            callback(self._translate_mouse_event(ev, MouseEventType.MOVE))

        self._axes.figure.canvas.mpl_connect("motion_notify_event", _cb)

    def _plt_connect_mouse_double_click(
        self, callback: Callable[[MouseEvent], None]
    ):
        """Connect *callback* to double-click button presses in this axes."""

        def _cb(ev: mplMouseEvent):
            if ev.inaxes is not self._axes or not ev.dblclick:
                return
            callback(self._translate_mouse_event(ev, MouseEventType.DOUBLE_CLICK))

        self._axes.figure.canvas.mpl_connect("button_press_event", _cb)

    def _translate_mouse_event(
        self, ev: mplMouseEvent, typ: MouseEventType
    ) -> MouseEvent:
        """Convert a matplotlib mouse event into a whitecanvas ``MouseEvent``.

        ``ev.key`` (e.g. ``"ctrl+shift"``) is split on ``+``. Parts listed in
        ``_MOUSE_MOD_MAP`` become modifiers; other parts are ignored. Buttons
        missing from ``_MOUSE_BUTTON_MAP`` (including None) map to
        ``MouseButton.NONE``.
        """
        modifiers: tuple[Modifier, ...] = ()
        if ev.key is not None:
            # Look each key part up once; keep only the recognized modifiers.
            modifiers = tuple(
                mod for k in ev.key.split("+") if (mod := _MOUSE_MOD_MAP.get(k))
            )
        return MouseEvent(
            pos=(ev.xdata, ev.ydata),
            button=_MOUSE_BUTTON_MAP.get(ev.button, MouseButton.NONE),
            modifiers=modifiers,
            type=typ,
        )

    def _plt_connect_xlim_changed(
        self, callback: Callable[[tuple[float, float]], None]
    ):
        """Connect callback to x-limits changed event"""
        self._axes.callbacks.connect("xlim_changed", lambda ax: callback(ax.get_xlim()))

    def _plt_connect_ylim_changed(
        self, callback: Callable[[tuple[float, float]], None]
    ):
        """Connect callback to y-limits changed event"""
        self._axes.callbacks.connect("ylim_changed", lambda ax: callback(ax.get_ylim()))
# matplotlib mouse button -> whitecanvas MouseButton. Both enums use the same
# member names, so the table is built from the shared names.
_MOUSE_BUTTON_MAP = {
    getattr(mplMouseButton, _btn): getattr(MouseButton, _btn)
    for _btn in ("LEFT", "MIDDLE", "RIGHT", "BACK", "FORWARD")
}

# Parts of matplotlib's ``MouseEvent.key`` string -> whitecanvas Modifier.
# Both "control" and "ctrl" are accepted as spellings of the control key.
_MOUSE_MOD_MAP = dict(
    control=Modifier.CTRL,
    ctrl=Modifier.CTRL,
    shift=Modifier.SHIFT,
    alt=Modifier.ALT,
    meta=Modifier.META,
)
@protocols.check_protocol(protocols.CanvasGridProtocol)
class CanvasGrid:
    """Matplotlib backend for a grid of canvases sharing one figure."""

    def __init__(self, heights: list[int], widths: list[int], app: str = "default"):
        """Create the grid layout and the figure.

        Parameters
        ----------
        heights, widths : list of int
            Relative row heights and column widths. Their lengths give the
            number of rows and columns of the grid.
        app : str
            ``"default"`` leaves the matplotlib backend unchanged. The short
            names ``"qt"``, ``"wx"``, ``"gtk"``, ``"tk"`` and ``"notebook"``
            are mapped to matplotlib backend names. Any other value is passed
            to ``matplotlib.use`` as-is.
        """
        nr, nc = len(heights), len(widths)
        self._gridspec = plt.GridSpec(
            nr, nc, height_ratios=heights, width_ratios=widths
        )
        backend = {
            "qt": "QtAgg",
            "wx": "WXAgg",
            "gtk": "GTK3Agg",
            "tk": "TkAgg",
            "notebook": "nbAgg",
        }.get(app, app)
        if backend != "default":
            mpl.use(backend)
        self._fig = plt.figure()

    def _plt_add_canvas(self, row: int, col: int, rowspan: int, colspan: int) -> Canvas:
        """Add an axes covering ``rowspan`` x ``colspan`` cells starting at (row, col)."""
        r1 = row + rowspan
        c1 = col + colspan
        axes = self._fig.add_subplot(self._gridspec[row:r1, col:c1])
        return Canvas(axes)

    def _plt_get_visible(self) -> bool:
        """Return whether the figure is visible."""
        return self._fig.get_visible()

    def _plt_show(self):
        """Show the figure."""
        self._fig.show()

    def _plt_get_background_color(self):
        """Return the figure face color as reported by matplotlib (RGBA)."""
        return self._fig.get_facecolor()

    def _plt_set_background_color(self, color):
        """Set the figure face color."""
        self._fig.set_facecolor(color)

    def _plt_screenshot(self):
        """Render the figure to a ``(height, width, channels)`` uint8 array."""
        import io

        fig = self._fig
        with io.BytesIO() as buff:
            # Render at the figure's own dpi so the buffer matches
            # get_width_height(), even if rcParams["savefig.dpi"] is customized.
            fig.savefig(buff, format="raw", dpi="figure")
            data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
        w, h = fig.canvas.get_width_height()
        return data.reshape((int(h), int(w), -1))

    def _plt_set_figsize(self, width: float, height: float):
        """Resize the figure to *width* x *height* pixels at the current dpi."""
        dpi = self._fig.get_dpi()
        self._fig.set_size_inches(width / dpi, height / dpi)
        self._fig.tight_layout()