# Copyright 2018-2021 Nick Anthony, Backman Biophotonics Lab, Northwestern University
#
# This file is part of mpl_qt_viz.
#
# mpl_qt_viz is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# mpl_qt_viz is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with mpl_qt_viz. If not, see <https://www.gnu.org/licenses/>.
from typing import Tuple, List, Optional
from PyQt5 import QtCore, QtGui
from PyQt5.QtCore import QSizeF, QTimer
from PyQt5.QtWidgets import QWidget, QGridLayout, QApplication, QPushButton, QGraphicsView, \
QGraphicsScene, QGroupBox, QVBoxLayout, QCheckBox, QButtonGroup, QMessageBox
from matplotlib import pyplot
from matplotlib.backends.backend_qt5 import NavigationToolbar2QT, FigureCanvasQT
import numpy as np
from mpl_qt_viz.roiSelection import LassoCreator, PointCreator, AdjustableSelector, AxManager
from ._canvas import PlotNdCanvas
from .._sharedWidgets import AnimationDlg, QRangeSlider
class _MyView(QGraphicsView):
"""This is a version of QGraphicsView that takes a Qt FigureCanvas from matplotlib and automatically resized the
canvas to fill as much of the view as possible. A debounce timer is used to prevent lag due to attempting the resize
the canvas too quickly. This allows for relatively smooth operation. This is essential for us to include a matplotlib
plot that can maintain it's aspect ratio within a Qt layout.
Args:
plot: A matplotlib FigureCanvas that is compatible with Qt (FigureCanvasQT or FigureCanvasQTAgg)
"""
def __init__(self, plot: FigureCanvasQT):
super().__init__()
self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
scene = QGraphicsScene(self)
scene.addWidget(plot)
self.plot = plot
self.setScene(scene)
self._debounce = QTimer() # If the resizeEvent doesn't reset the timer within the 50ms timeout interval then _resizePlot will be called.
self._debounce.setSingleShot(True)
self._debounce.setInterval(50)
self._debounce.timeout.connect(self._resizePlot)
def _resizePlot(self):
"""This method is indirectly called by the resizeEvent through the debounce timer and sets the size of the plot
to maximize its size without changing aspect ratio."""
w, h = self.size().width(), self.size().height()
r = self.scene().sceneRect()
s = min([w, h]) # Get the side length of the biggest square that can fit within the rectangle view area.
self.plot.resize(s, s) # Set the plot to the size of the square that fits in view.
r.setSize(QSizeF(s, s))
self.scene().setSceneRect(r) # Set the scene size to the square that fits in view.
def resizeEvent(self, event: QtGui.QResizeEvent):
"""Every time that the view is resized this event will fire and start the debounce timer. The timer will only
actually time out if this event doesn't restart it within the timeout period."""
self._debounce.start()
super().resizeEvent(event)
[docs]class PlotNd(QWidget): #TODO add function and GUI method to set coordinates of cursor.
"""A convenient widget for visualizing data that is 3D or greater. This is a standalone widget which extends the
functionality of `PlotNdCanvas`.
Args:
data: A 3D or greater numpy array of numeric values.
names: A sequence of labels for each axis of the data array.
initialCoords: An optional sequence of the coordinates to initially se the ND crosshair to. There should be one
coordinate for each axis of the data array.
title: A title for the window.
parent: The Qt Widget that serves as the parent for this widget.
indices: An optional tuple of 1d arrays of values to set as the indexes for each dimension of the data. Elements of the list can be set to `None` to skip
setting a custom index for that dimension.
Attributes:
data: A reference the the 3D or greater numpy array. This can be safely modified.
"""
_defaultNames = ('y', 'x', 'z', '4th', '5th', '6th', '7th', '8th', '9th', '10th', '11th')
def __init__(self, data: np.ndarray, names: Tuple[str, ...] = None,
initialCoords: Optional[Tuple[int, ...]] = None, title: Optional[str] = '',
parent: Optional[QWidget] = None, indices: List[np.ndarray] = None):
super().__init__(parent=parent)
self.setWindowTitle(str(title)) # Convert to string just in case
if names is None:
names = PlotNd._defaultNames[:len(data.shape)]
if data.dtype == bool:
data = data.astype(np.uint8)
self.canvas = PlotNdCanvas(data, names, initialCoords, indices)
self.view = _MyView(self.canvas)
self.slider = QRangeSlider(self)
self.slider.setMaximumHeight(20)
self.slider.setMax(np.nanmax(data))
self.slider.setMin(np.nanmin(data))
self.slider.setEnd(np.nanmax(data))
self.slider.setStart(np.nanmin(data))
_ = lambda: self.canvas.updateLimits(self.slider.end(), self.slider.start())
self.slider.startValueChanged.connect(_)
self.slider.endValueChanged.connect(_)
self._lastButton = None
self._axesManager = AxManager(self.canvas.image.ax)
self.selector = AdjustableSelector(self._axesManager, self.canvas.image.im, LassoCreator,
onfinished=self._selectorFinished)
self.buttonWidget = QGroupBox("Control", self)
self.buttonWidget.setLayout(QVBoxLayout())
check = QCheckBox("Cursor Active")
self.buttonWidget.layout().addWidget(check)
check.setChecked(self.canvas.spectraViewActive) # Get the right initial value
check.stateChanged.connect(lambda state: self.canvas.setSpectraViewActive(state != 0))
self.buttonGroup = QButtonGroup()
self.pointButton = QPushButton("Point")
self.buttonGroup.addButton(self.pointButton)
self.buttonWidget.layout().addWidget(self.pointButton)
self.lassoButton = QPushButton("Lasso")
self.buttonGroup.addButton(self.lassoButton)
self.buttonWidget.layout().addWidget(self.lassoButton)
self.noneButton = QPushButton("None")
self.buttonGroup.addButton(self.noneButton)
self.buttonWidget.layout().addWidget(self.noneButton)
for b in self.buttonGroup.buttons():
b.setCheckable(True)
self.noneButton.setChecked(True)
self.buttonGroup.buttonReleased.connect(self._handleButtons)
self.rotateButton = QPushButton("Rotate Axes")
self.rotateButton.released.connect(self.canvas.rollAxes)
self.saveButton = QPushButton("Save Animation")
self.saveButton.released.connect(self._saveAnimation)
layout = QGridLayout()
layout.addWidget(self.view, 0, 0, 8, 8)
layout.addWidget(self.buttonWidget, 0, 8, 4, 1)
layout.addWidget(self.rotateButton, 5, 8)
layout.addWidget(self.saveButton, 6, 8)
layout.addWidget(NavigationToolbar2QT(self.canvas, self), 10, 0, 1, 8)
layout.setRowStretch(0, 1)
layout.addWidget(self.slider, 8, 0, 1, 7)
self.setLayout(layout)
self.show()
def _saveAnimation(self):
def animationUpdaterFunc(z: int):
"""Used by the animation saver to iterate through the 3rd dimension of the data."""
self.canvas.coords = self.canvas.coords[:2] + (z,) + self.canvas.coords[3:]
self.canvas.updatePlots()
dlg = AnimationDlg(self.canvas.fig, (animationUpdaterFunc, range(self.canvas.data.shape[2])), self)
dlg.exec()
def _updateLimits(self):
""""""
self.canvas.updateLimits(self.slider.end(), self.slider.start())
def _handleButtons(self, button: QPushButton):
"""Acts as a callback when one of the ROI drawing buttons is pressed. Activates the corresponding ROI selector
Args:
button: The button that was just pressed.
"""
if button is self.pointButton and button is not self._lastButton:
self.selector.setSelector(PointCreator)
self.selector.setActive(True)
if button is self.lassoButton and button is not self._lastButton:
self.selector.setSelector(LassoCreator)
self.selector.setActive(True)
if button is self.noneButton and button is not self._lastButton:
self.selector.setActive(False)
self._lastButton = button
def _selectorFinished(self, verts: np.ndarray):
"""When an ROI selector finishes selecting a region the vertex coordinates of the selection are passed to this
function. The function then uses the vertices to plot the average of the data in the ROI"""
from pwspy.dataTypes import Roi
newVerts = []
for vert in verts: # Convert `verts` from being in terms of the values in self.canvas._indexes to being in terms of the element locations of the data array.
v1 = self.canvas.image.verticalValueToCoord(vert[1])
v0 = self.canvas.image.horizontalValueToCoord(vert[0])
newVerts.append((v0, v1))
verts = newVerts
roi = Roi.fromVerts('nomatter', 0, np.array(verts), self.canvas.data.shape[:2]) # A 2d ROI to select from the data
selected = self.canvas.data[roi.mask] # For a 3d data array this will now be 2d . For a 4d array it will be 3d etc. The 0th axis is one element for each selected pixel.
selected = selected.mean(axis=0) # Get the average over all selected pixels. We are now down to 1d for a 3d data array, 2d for a 4d data array, et.
if len(selected.shape) == 1:
fig, ax = pyplot.subplots()
ax.plot(self.canvas.indexes[2], selected)
ax.set_xlabel(self.canvas.names[2])
fig.show()
elif len(selected.shape) == 2:
fig, ax = pyplot.subplots()
im = ax.imshow(selected)
im.set_extent([self.canvas.indexes[3][0], self.canvas.indexes[3][-1], self.canvas.indexes[2][0], self.canvas.indexes[2][-1]])
ax.set_xlabel(self.canvas.names[3])
ax.set_ylabel(self.canvas.names[2])
fig.show()
else: # selected must be 3d or greater. This means our original data was 5d or greater.
p = PlotNd(selected, names=self.canvas.names[2:], indices=self.canvas.indexes[2:])
self.selector.setActive(True) # Reset the selector.
# API
@property
def data(self):
return self.canvas.data
@data.setter
def data(self, data: np.ndarray):
self.canvas.data = data
def setLimits(self, Min: float, Max: float):
return self.canvas.updateLimits(Max, Min)
def setColorMap(self, cmap):
self.canvas.setColorMap(cmap)
if __name__ == '__main__':
import sys
print("Starting")
x = np.linspace(0, 1, num=100)
y = np.linspace(0, 1, num=50)
z = np.linspace(0, 3, num=101)
# t = np.linspace(0, 1, num=3)
# c = np.linspace(12, 13, num=3)
X, Y, Z = np.meshgrid(x, y, z)
arr = np.sin(2 * np.pi * 1 * Z) + .5 * X + np.cos(2*np.pi*4*Y)# * T**1.5 * C*.1
app = QApplication(sys.argv)
p = PlotNd(arr[:,:,:], names=('y', 'x', 'z'), indices=[y, x, z]) # 3d
# p = PlotNd(arr[:,:,:,:,0], names=('y', 'x', 'z', 't'), indices=[y, x, z, t]) #4d
# p = PlotNd(arr, names=('y', 'x', 'z', 't', 'c'), indices=[y, x, z, t, c]) #5d
sys.exit(app.exec_())