from __future__ import annotations
import numpy as np
from numpy.typing import NDArray
from whitecanvas.protocols import MarkersProtocol, check_protocol
from whitecanvas.types import Symbol
from whitecanvas.utils.normalize import arr_color, as_color_array, rgba_str_color
from whitecanvas.backend import _not_implemented
from ._base import (
PlotlyLayer,
to_plotly_marker_symbol,
from_plotly_marker_symbol,
)
@check_protocol(MarkersProtocol)
class Markers(PlotlyLayer):
    """Markers layer backed by a plotly ``scatter`` trace property dict.

    All visual state lives in ``self._props``. Per-point properties (face
    color, size, edge width, edge color) are stored with one entry per data
    point, while the marker symbol is a single value shared by all points.
    """

    def __init__(self, xdata, ydata):
        """Build the trace properties for the given x/y data.

        Parameters
        ----------
        xdata, ydata
            Sized sequences of coordinates. ``len(xdata)`` determines how
            many per-point entries are created.
        """
        ndata = len(xdata)
        self._props = {
            "x": xdata,
            "y": ydata,
            "mode": "markers",
            "marker": {
                "color": ["blue"] * ndata,
                "size": np.full(ndata, 10),
                "symbol": "circle",
                "line": {"width": np.ones(ndata), "color": ["blue"] * ndata},
            },
            "type": "scatter",
            "showlegend": False,
            "visible": True,
            # Each point carries (hover text, id of this layer). The id is
            # what `_plt_set_hover_text` uses to find this layer's trace.
            "customdata": list(zip([""] * ndata, [id(self)] * ndata)),
            "hovertemplate": "%{customdata[0]}<extra></extra>",
        }
        # No figure is attached at construction; the reference resolves to None.
        self._fig_ref = lambda: None
        # Pick callbacks registered while no figure is attached.
        self._click_callbacks = []

    def _plt_get_ndata(self) -> int:
        """Return the number of data points (length of the x data)."""
        return len(self._props["x"])

    def _plt_get_data(self):
        """Return the ``(x, y)`` data currently stored in the properties."""
        return self._props["x"], self._props["y"]

    def _plt_set_data(self, xdata, ydata):
        """Replace the x and y data.

        Only ``"x"`` and ``"y"`` are replaced; per-point properties are not
        resized here.
        """
        self._props["x"] = xdata
        self._props["y"] = ydata

    def _plt_get_face_color(self) -> NDArray[np.float32]:
        """Return the face colors as a stacked array, one row per point."""
        color = self._props["marker"]["color"]
        # `np.stack` rejects an empty list, so return an empty (0, 4) array.
        if len(color) == 0:
            return np.empty((0, 4), dtype=np.float32)
        return np.stack([arr_color(c) for c in color], axis=0)

    def _plt_set_face_color(self, color: NDArray[np.float32]):
        """Set the face colors, normalized to one color string per point."""
        color = as_color_array(color, self._plt_get_ndata())
        self._props["marker"]["color"] = [rgba_str_color(c) for c in color]

    _plt_get_face_pattern, _plt_set_face_pattern = _not_implemented.face_patterns()

    def _plt_get_symbol(self) -> Symbol:
        """Return the marker symbol converted from its plotly name."""
        return from_plotly_marker_symbol(self._props["marker"]["symbol"])

    def _plt_set_symbol(self, symbol: Symbol):
        """Set the marker symbol, stored under its plotly name."""
        self._props["marker"]["symbol"] = to_plotly_marker_symbol(symbol)

    def _plt_get_symbol_size(self) -> NDArray[np.floating]:
        """Return the marker sizes as an array."""
        return np.asarray(self._props["marker"]["size"])

    def _plt_set_symbol_size(self, size: float | NDArray[np.floating]):
        """Set the marker sizes; a scalar is expanded to one value per point."""
        if isinstance(size, (int, float, np.number)):
            size = np.full(self._plt_get_ndata(), size)
        self._props["marker"]["size"] = size

    def _plt_get_edge_width(self) -> NDArray[np.floating]:
        """Return the marker edge widths as an array."""
        return np.asarray(self._props["marker"]["line"]["width"])

    def _plt_set_edge_width(self, width: float | NDArray[np.floating]):
        """Set the edge widths; a scalar is expanded to one value per point."""
        if isinstance(width, (int, float, np.number)):
            width = np.full(self._plt_get_ndata(), width)
        self._props["marker"]["line"]["width"] = width

    _plt_get_edge_style, _plt_set_edge_style = _not_implemented.edge_styles()

    def _plt_get_edge_color(self) -> NDArray[np.float32]:
        """Return the edge colors as a stacked array, one row per point."""
        color = self._props["marker"]["line"]["color"]
        # `np.stack` rejects an empty list, so return an empty (0, 4) array.
        if len(color) == 0:
            return np.empty((0, 4), dtype=np.float32)
        return np.stack([arr_color(c) for c in color], axis=0)

    def _plt_set_edge_color(self, color: NDArray[np.float32]):
        """Set the edge colors, normalized to one color string per point."""
        color = as_color_array(color, self._plt_get_ndata())
        self._props["marker"]["line"]["color"] = [rgba_str_color(c) for c in color]

    def _plt_connect_pick_event(self, callback):
        """Register ``callback`` for pick events.

        Raises
        ------
        NotImplementedError
            If a figure is already attached; only registration before
            attachment is supported.
        """
        if self._fig_ref() is not None:
            raise NotImplementedError("post connection not implemented yet")
        self._click_callbacks.append(callback)

    def _plt_set_hover_text(self, text: list[str]):
        """Set the per-point hover text.

        The text is stored in ``customdata`` as ``(text, id(self))`` pairs,
        which is what the ``hovertemplate`` displays. With no figure attached
        the pairs are written to ``self._props``; otherwise the matching trace
        in the figure is updated.
        """
        customdata = list(zip(text, [id(self)] * self._plt_get_ndata()))
        fig = self._fig_ref()
        if fig is None:
            # Keep the text instead of silently dropping it.
            # NOTE(review): assumes the trace is later built from `_props`
            # when the layer is added to a figure -- confirm against the canvas.
            self._props["customdata"] = customdata
            return

        # Find this layer's trace by the id stored in its customdata.
        def selector(trace):
            s = trace["customdata"]
            # Traces without customdata, or with empty customdata (zero
            # points), cannot be identified; indexing `s[0]` would fail.
            if s is None or len(s) == 0:
                return False
            return s[0][1] == id(self)

        fig.update_traces(customdata=customdata, selector=selector)