Source code for whitecanvas.backend.plotly.canvas

import sys
from typing import Callable
import weakref
import numpy as np
from plotly import graph_objects as go

from whitecanvas import protocols
from whitecanvas.types import MouseEventType, MouseEvent
from whitecanvas.utils.normalize import rgba_str_color
from .markers import Markers
from ._base import PlotlyLayer
from ._labels import Title, AxisLabel, Axis, Ticks


class Canvas:
    """Plotly backend canvas that draws into a ``go.FigureWidget``.

    Parameters
    ----------
    fig : go.FigureWidget, optional
        Figure to draw into. A new empty ``FigureWidget`` is created when
        omitted.
    row, col : int
        0-based cell position. Stored 1-based (``row + 1``, ``col + 1``) and
        passed as ``row``/``col`` to ``add_trace`` when a layer is added.
    """

    def __init__(
        self,
        fig: go.FigureWidget | None = None,
        row: int = 0,
        col: int = 0,
    ):
        # prepare widget
        if fig is None:
            fig = go.FigureWidget()
        self._fig = fig
        self._xaxis = Axis(self, axis="x")
        self._yaxis = Axis(self, axis="y")
        self._xticks = Ticks(self, axis="x")
        self._yticks = Ticks(self, axis="y")
        self._title = Title(self)
        self._xlabel = AxisLabel(self, axis="x")
        self._ylabel = AxisLabel(self, axis="y")

        # add empty scatter just for click events
        self._scatter = go.Scatter(
            x=[], y=[], mode="markers", marker_opacity=0, showlegend=False
        )
        self._fig.add_trace(self._scatter)
        self._row = row + 1
        self._col = col + 1

    def _plt_get_native(self):
        """Return the underlying plotly figure."""
        return self._fig

    def _plt_get_title(self):
        """Return the title component of this canvas."""
        return self._title

    def _plt_get_xaxis(self):
        """Return the x-axis component of this canvas."""
        return self._xaxis

    def _plt_get_yaxis(self):
        """Return the y-axis component of this canvas."""
        return self._yaxis

    def _plt_get_xlabel(self):
        """Return the x-axis label component of this canvas."""
        return self._xlabel

    def _plt_get_xticks(self):
        """Return the x-axis ticks component of this canvas."""
        return self._xticks

    def _plt_get_yticks(self):
        """Return the y-axis ticks component of this canvas."""
        return self._yticks

    def _plt_get_ylabel(self):
        """Return the y-axis label component of this canvas."""
        return self._ylabel

    def _plt_reorder_layers(self, layers: list[PlotlyLayer]):
        """Reorder the figure's traces so they follow the order of ``layers``.

        The first trace of the figure is kept in place; every other trace must
        be the ``_props`` object of one of ``layers`` (matched by identity),
        otherwise ``KeyError`` is raised.
        """
        rank = {id(layer._props): i for i, layer in enumerate(layers)}
        # Presumably the first trace is the placeholder scatter added in
        # __init__ for click events, so it is not part of ``layers``.
        first, *data = self._fig._data
        # Sort by the requested position. Indexing ``data`` with the rank of
        # each current trace would apply the inverse permutation instead.
        self._fig._data = [first, *sorted(data, key=lambda r: rank[id(r)])]

    def _plt_get_aspect_ratio(self) -> float | None:
        """Get aspect ratio of canvas

        Returns 1 when the y axis is anchored to the x axis (``scaleanchor``
        equals ``'x'``), otherwise None.
        """
        try:
            locked = self._fig['layout']['yaxis']['scaleanchor'] == 'x'
        except KeyError:
            locked = False
        if locked:
            return 1
        return None

    def _plt_set_aspect_ratio(self, ratio: float | None):
        """Set aspect ratio of canvas

        Only ``None`` (unlocked) and ``1`` (y axis anchored to x) are
        supported; any other value raises ``NotImplementedError``.
        """
        if ratio is None:
            self._fig['layout']['yaxis']['scaleanchor'] = None
        elif ratio == 1:
            self._fig['layout']['yaxis']['scaleanchor'] = 'x'
        else:
            raise NotImplementedError(
                f"Invalid aspect ratio for plotly backend: {ratio}"
            )

    def _plt_add_layer(self, layer: PlotlyLayer):
        """Add the layer's trace to the figure cell of this canvas."""
        self._fig.add_trace(layer._props, row=self._row, col=self._col)
        # Rebind to the entry stored by the figure so that later identity
        # based lookups (reorder / remove) refer to the same object.
        layer._props = self._fig._data[-1]
        if isinstance(layer, Markers):
            gobj: go.Scatter = self._fig.data[-1]
            for cb in layer._click_callbacks:
                gobj.on_click(_convert_cb(cb), append=True)
        layer._fig_ref = weakref.ref(self._fig)

    def _plt_remove_layer(self, layer: PlotlyLayer):
        """Remove layer from the canvas

        Raises ``ValueError`` if the layer's trace is not in the figure.
        """
        data = self._fig._data
        target = layer._props
        # Match by identity first: ``list.remove`` compares by equality, which
        # would drop an earlier trace with equal contents instead of this one.
        for i, props in enumerate(data):
            if props is target:
                del data[i]
                return
        # Fall back to the equality based removal for backward compatibility.
        data.remove(target)

    def _plt_get_visible(self) -> bool:
        """Get visibility of canvas"""
        return self._fig.layout.visibility == "visible"

    def _plt_set_visible(self, visible: bool):
        """Set visibility of canvas"""
        if visible:
            self._fig.layout.visibility = "visible"
        else:
            self._fig.layout.visibility = "hidden"

    def _plt_connect_xlim_changed(self, callback):
        """Call ``callback`` with the new value when 'xaxis.range' changes."""
        self._fig.layout.on_change(lambda _, lim: callback(lim), 'xaxis.range')

    def _plt_connect_ylim_changed(self, callback):
        """Call ``callback`` with the new value when 'yaxis.range' changes."""
        self._fig.layout.on_change(lambda _, lim: callback(lim), 'yaxis.range')

    def _plt_connect_mouse_click(self, callback: Callable[[MouseEvent], None]):
        """Connect callback to clicked event"""
        # TODO

    def _plt_connect_mouse_drag(self, callback: Callable[[MouseEvent], None]):
        """Connect callback to drag event"""
        # TODO

    def _plt_connect_mouse_double_click(self, callback: Callable[[MouseEvent], None]):
        """Connect callback to double-clicked event"""
        # TODO


def _convert_cb(cb):
    """Wrap ``cb`` as a plotly click handler that forwards ``points.point_inds``."""
    return lambda _, points, state: cb(points.point_inds)
@protocols.check_protocol(protocols.CanvasGridProtocol)
class CanvasGrid:
    """Grid of canvases that share one plotly ``FigureWidget`` with subplots.

    Parameters
    ----------
    heights, widths : list of int
        Passed to ``make_subplots`` as ``row_heights`` / ``column_widths``;
        their lengths give the number of rows / columns.
    app : str
        Name of the application backend, consulted by ``_plt_show``.
    """

    def __init__(self, heights: list[int], widths: list[int], app: str = "default"):
        from plotly.subplots import make_subplots

        subplots = make_subplots(
            rows=len(heights),
            cols=len(widths),
            row_heights=heights,
            column_widths=widths,
        )
        self._figs = go.FigureWidget(subplots)
        self._figs.update_layout(margin={"l": 6, "r": 6, "t": 6, "b": 6})
        self._app = app

    def _plt_add_canvas(self, row: int, col: int, rowspan: int, colspan: int) -> Canvas:
        """Create a canvas bound to the shared figure at ``(row, col)``.

        ``rowspan`` and ``colspan`` are accepted but not used here.
        """
        return Canvas(self._figs, row=row, col=col)

    def _plt_show(self):
        """Show the figure.

        Returns ``NotImplemented`` for the "qt", "wx" and "tk" apps. When the
        app is "notebook" or IPython is already imported, and the active shell
        class is named "ZMQInteractiveShell", the widget is shown through
        ``IPython.display.display``. In every other case ``show()`` of the
        figure is called.
        """
        if self._app in ("qt", "wx", "tk"):
            return NotImplemented

        use_ipython = self._app == "notebook" or "IPython" in sys.modules
        if use_ipython:
            from IPython.display import display
            from IPython import get_ipython

            shell = get_ipython()
            if shell.__class__.__name__ == "ZMQInteractiveShell":
                display(self._figs)
                return

        self._figs.show()

    def _plt_get_background_color(self):
        """Return the ``paper_bgcolor`` of the figure layout."""
        layout = self._figs.layout
        return layout.paper_bgcolor

    def _plt_set_background_color(self, color):
        """Set the ``paper_bgcolor`` of the figure layout from ``color``."""
        color_str = rgba_str_color(color)
        self._figs.layout.paper_bgcolor = color_str

    def _plt_screenshot(self):
        """Render the figure to PNG and return it as a ``uint8`` array."""
        from PIL import Image
        import io

        layout = self._figs.layout
        width, height = layout.width, layout.height
        png_bytes = self._figs.to_image(format="png", width=width, height=height)
        buffer = io.BytesIO(png_bytes)
        image = Image.open(buffer)
        return np.asarray(image, dtype=np.uint8)

    def _plt_set_figsize(self, width: float, height: float):
        """Set the layout width and height of the figure."""
        self._figs.layout.width = width
        self._figs.layout.height = height