# Copyright 2018-2021 Nick Anthony, Backman Biophotonics Lab, Northwestern University
#
# This file is part of mpl_qt_viz.
#
# mpl_qt_viz is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# mpl_qt_viz is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with mpl_qt_viz. If not, see <https://www.gnu.org/licenses/>.
from typing import List
from PyQt5 import QtCore
from PyQt5.QtWidgets import QWidget, QGridLayout, QPushButton, QApplication
import numpy as np
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.backends.backend_qt5 import NavigationToolbar2QT
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
from matplotlib.figure import Figure
import matplotlib.pyplot as plt
from ._sharedWidgets import AnimationDlg
[docs]class MultiPlot(QWidget):
"""
A widget that allows the user to flip through a set of matplotlib artists (images, plots, etc.)
Args:
artists: A list of lists of matplotlib 'Artists`. each list will comprise a single frame, just like the matplotlib `ArtistAnimation` works.
title (str): The name for the title of the window
"""
def __init__(self, artists: List[List[Artist]], title: str, parent=None):
QWidget.__init__(self, parent=parent, flags=QtCore.Qt.Window)
self.setWindowTitle(title)
layout = QGridLayout()
self._artists = artists
self.figure: Figure = artists[0][0].figure # We are assuming that all artists use the same figure.
self.ax: Axes = self._artists[0][0].axes
self._canvas = FigureCanvasQTAgg(self.figure)
self._canvas.setFocusPolicy(QtCore.Qt.ClickFocus) #Not sure what this is for
self._canvas.setFocus()
self.previousButton = QPushButton('←')
self.nextButton = QPushButton('→')
self.previousButton.released.connect(self.showPreviousIm)
self.nextButton.released.connect(self.showNextIm)
self.saveButton = QPushButton("Save Animation")
self.saveButton.released.connect(lambda: AnimationDlg(self.figure, self._artists, self).exec())
layout.addWidget(self._canvas, 0, 0, 8, 8)
layout.addWidget(self.previousButton, 9, 1, 1, 1)
layout.addWidget(self.nextButton, 9, 2, 1, 1)
layout.addWidget(self.saveButton, 9, 7, 1, 1)
layout.addWidget(NavigationToolbar2QT(self._canvas, self), 10, 0, 1, 4)
layout.setRowStretch(0, 1) # This causes the plot to take up all the space that isn't needed by the other widgets.
self.setLayout(layout)
self._index = 0 # Keeps track of which frame we are currently viewing.
self._updateDisplayedImage()
plt.close(self.figure) # A separate figure window will have been opened during the creation of our artists. This is extra and should be closed.
[docs] def showPreviousIm(self):
"""Display the previous set of display elements."""
self._index -= 1
if self._index < 0:
self._index = len(self._artists) - 1
self._updateDisplayedImage()
[docs] def showNextIm(self):
"""Display the next set of display elements."""
self._index += 1
if self._index >= len(self._artists):
self._index = 0
self._updateDisplayedImage()
[docs] def imshow(self, *args, **kwargs):
"""Mirrors the pyplot.imshow function. Adds a new image to the set of images shown by this widget."""
self._artists.append([self.ax.imshow(*args, **kwargs)])
self._index = len(self._artists) - 1
self._updateDisplayedImage()
def _updateDisplayedImage(self):
for i, frame in enumerate(self._artists):
for artist in frame:
artist.set_visible(self._index == i)
self._canvas.draw_idle()
if __name__ == '__main__':
import sys
import matplotlib.pyplot as plt
app = QApplication(sys.argv)
sh = (1024, 1024)
ims = [[plt.imshow(np.random.random(sh)), plt.text(100, 100, str(i))] for i in range(3)]
mp = MultiPlot(ims, "Hey")
plt.gcf().subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
mp.ax.get_xaxis().set_visible(False)
mp.ax.get_yaxis().set_visible(False)
[mp.imshow(np.random.random(sh)) for i in range(3)]
mp.show()
fig, ax = plt.subplots()
lines = [ax.plot(np.random.random((50,))) for i in range(3)]
mp2 = MultiPlot(lines, 'Lines')
mp2.show()
sys.exit(app.exec())