Source code for whitecanvas.layers.primitive.image

from __future__ import annotations

import numpy as np
from numpy.typing import ArrayLike, NDArray
from cmap import Colormap
from whitecanvas.protocols import ImageProtocol
from whitecanvas.types import ColormapType, _Void
from whitecanvas.backend import Backend
from whitecanvas.layers._base import PrimitiveLayer

_void = _Void()


class Image(PrimitiveLayer[ImageProtocol]):
    """
    Grayscale or RGBA image layer.

    Parameters
    ----------
    image : array_like
        2D or 3D array of image data.
    name : str, optional
        Name of the layer.
    cmap : colormap type, default is "gray"
        Colormap to use. Ignored for 3D (RGB/RGBA) images.
    clim : tuple of float or None, optional
        Contrast limits. If ``None``, the limits are set to the min and max of the
        data. You can also pass ``None`` separately to either limit to only
        autoscale one of them. Ignored for 3D (RGB/RGBA) images.
    shift : tuple of float, default is (0, 0)
        Shift of the image from the origin, as (x, y).
    scale : float or tuple of float, default is (1.0, 1.0)
        Scale of the image, as (x, y). Both values must be positive.
    backend : Backend or str, optional
        Backend used to create the layer.
    """

    def __init__(
        self,
        image: ArrayLike,
        *,
        name: str | None = None,
        cmap: ColormapType = "gray",
        clim: tuple[float | None, float | None] | None = None,
        shift: tuple[float, float] = (0, 0),
        scale: tuple[float, float] = (1.0, 1.0),
        backend: Backend | str | None = None,
    ):
        img = _normalize_image(image)
        super().__init__(name=name)
        self._backend = self._create_backend(Backend(backend), img)
        # Colormap and contrast limits only make sense for 2D (grayscale) data;
        # `update` raises if either is given for a 3D image, so drop them there.
        if img.ndim == 3:
            cmap = clim = _void
        self.update(cmap=cmap, clim=clim, shift=shift, scale=scale)
        # Recompute the hints with the shift/scale that were just applied so
        # that they are not reset to the untransformed extent.
        self._x_hint, self._y_hint = _hint_for(
            (img.shape[1], img.shape[0]), shift=self.shift, scale=self.scale
        )

    @property
    def data(self) -> NDArray[np.number]:
        """Current data of the layer."""
        return self._backend._plt_get_data()

    @data.setter
    def data(self, data: ArrayLike):
        self._backend._plt_set_data(_normalize_image(data))

    @property
    def cmap(self) -> Colormap:
        """Current colormap."""
        return self._backend._plt_get_colormap()

    @cmap.setter
    def cmap(self, cmap: ColormapType):
        self._backend._plt_set_colormap(Colormap(cmap))

    @property
    def clim(self) -> tuple[float, float]:
        """Current contrast limits."""
        return self._backend._plt_get_clim()

    @clim.setter
    def clim(self, clim: tuple[float | None, float | None] | None):
        if clim is None:
            low, high = None, None
        else:
            low, high = clim
        # A missing limit is autoscaled from the current data.
        if low is None:
            low = self.data.min()
        if high is None:
            high = self.data.max()
        self._backend._plt_set_clim((low, high))

    @property
    def shift(self) -> tuple[float, float]:
        """Current shift from the origin."""
        return self._backend._plt_get_translation()

    @shift.setter
    def shift(self, shift: tuple[float, float]):
        self._backend._plt_set_translation(shift)
        img = self.data
        # x corresponds to the column axis (shape[1]), y to the row axis (shape[0]).
        self._x_hint, self._y_hint = _hint_for(
            (img.shape[1], img.shape[0]), shift=shift, scale=self.scale
        )

    @property
    def scale(self) -> tuple[float, float]:
        """Current scale."""
        return self._backend._plt_get_scale()

    @scale.setter
    def scale(self, scale: float | tuple[float, float]):
        # A scalar is applied to both axes.
        if isinstance(scale, (int, float, np.number)):
            scale = float(scale), float(scale)
        dx, dy = scale
        if dx <= 0 or dy <= 0:
            raise ValueError("Scale must be positive.")
        self._backend._plt_set_scale(scale)
        img = self.data
        self._x_hint, self._y_hint = _hint_for(
            (img.shape[1], img.shape[0]), shift=self.shift, scale=scale
        )

    def update(
        self,
        *,
        cmap: ColormapType | _Void = _void,
        clim: tuple[float | None, float | None] | None | _Void = _void,
        shift: tuple[float, float] | _Void = _void,
        scale: tuple[float, float] | _Void = _void,
    ) -> Image:
        """
        Update several properties of the image at once.

        Only the arguments that are explicitly given are applied.

        Raises
        ------
        ValueError
            If ``cmap`` or ``clim`` is given for an RGBA image.

        Returns
        -------
        Image
            The layer itself.
        """
        if cmap is not _void:
            if self.is_rgba:
                raise ValueError("Cannot set colormap for an RGBA image.")
            self.cmap = cmap
        if clim is not _void:
            if self.is_rgba:
                raise ValueError("Cannot set contrast limits for an RGBA image.")
            self.clim = clim
        if shift is not _void:
            self.shift = shift
        if scale is not _void:
            self.scale = scale
        return self

    def fit_to(self, bbox: tuple[float, float, float, float]) -> Image:
        """
        Fit the image to the given bounding box.

        Parameters
        ----------
        bbox : tuple of float
            Bounding box given as ``(x0, y0, x1, y1)``.

        Returns
        -------
        Image
            The layer itself.
        """
        x0, y0, x1, y1 = bbox
        dx, dy = x1 - x0, y1 - y0
        nrows, ncols = self.data.shape[:2]
        self.shift = (x0 + 0.5, y0 + 0.5)
        # The x extent spans the columns (shape[1]) and the y extent spans the
        # rows (shape[0]), matching how the axis hints are computed.
        self.scale = (dx / ncols, dy / nrows)
        return self

    @property
    def is_rgba(self) -> bool:
        """Whether the image is RGBA."""
        return self.data.ndim == 3
def _normalize_image(image):
    """
    Convert the input to an array and check that it is a valid image.

    Parameters
    ----------
    image : array_like
        Input data, converted with ``np.asarray``.

    Returns
    -------
    np.ndarray
        A 2D array, or a 3D array whose last dimension is 3 or 4.

    Raises
    ------
    TypeError
        If the dtype is not unsigned integer, signed integer or float.
    ValueError
        If the array is not 2D or 3D, or if a 3D array does not have 3 or 4
        channels in its last dimension.
    """
    img = np.asarray(image)
    if img.dtype.kind not in "uif":
        raise TypeError(f"Only numerical arrays are allowed, got {img.dtype}")
    # check shape
    if img.ndim not in (2, 3):
        raise ValueError(f"Only 2D or 3D arrays are allowed, got {img.ndim}")
    if img.ndim == 3 and img.shape[2] not in (3, 4):
        raise ValueError(
            "If 3D array is given, the last dimension must be 3 or 4, "
            f"got shape {img.shape}."
        )
    return img


def _hint_for(
    shape: tuple[int, int],
    shift: tuple[float, float] = (0, 0),
    scale: tuple[float, float] = (1, 1),
) -> tuple[tuple[float, float], tuple[float, float]]:
    """
    Compute the x and y range hints of an image.

    Parameters
    ----------
    shape : tuple of int
        Number of pixels along x and y, in this order.
    shift : tuple of float, default is (0, 0)
        Shift along x and y.
    scale : tuple of float, default is (1, 1)
        Scale along x and y.

    Returns
    -------
    tuple of (float, float)
        ``(low, high)`` pairs for the x axis and the y axis.
    """
    # NOTE(review): a half-pixel offset is subtracted both inside the edge
    # array and again on return, so the default hint is (-1, n - 1) rather
    # than (-0.5, n - 0.5). Looks like a double offset -- confirm against the
    # backend's translation convention before changing.
    nx, ny = shape
    xhint = np.array([-0.5, nx - 0.5]) * scale[0] + shift[0]
    yhint = np.array([-0.5, ny - 0.5]) * scale[1] + shift[1]
    return tuple(xhint - 0.5), tuple(yhint - 0.5)