Source code for grogupy.viz.viz

# Copyright (c) [2024-2025] [Laszlo Oroszlany, Daniel Pozsar]
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Visualization functions
=======================

.. currentmodule:: grogupy.viz.viz

.. autosummary::
   :toctree: _generated/

   plot_contour
   plot_kspace
   plot_magnetic_entities
   plot_pairs
   plot_DMI

"""

from typing import TYPE_CHECKING, Union

from ..physics import Builder

if TYPE_CHECKING:
    from ..physics.contour import Contour
    from ..physics.kspace import Kspace
    from ..physics.magnetic_entity import MagneticEntity
    from ..physics.pair import Pair

import numpy as np
import plotly.graph_objs as go


[docs] def plot_contour(contour: "Contour") -> go.Figure: """Creates a plot from the contour sample points. Parameters ---------- contour : Contour Contour class that contains the energy samples and weights Returns ------- plotly.graph_objs.go.Figure The created figure """ # Create the scatter plot fig = go.Figure( data=go.Scatter(x=contour.samples.real, y=contour.samples.imag, mode="markers") ) # Update the layout fig.update_layout( autosize=False, width=800, height=500, title="Energy contour integral", xaxis_title="Real axis [eV]", yaxis_title="Imaginary axis [eV]", xaxis=dict( showgrid=True, gridwidth=1, ), yaxis=dict( showgrid=True, gridwidth=1, ), ) fig.update_yaxes( scaleanchor="x", scaleratio=1, ) return fig
[docs] def plot_kspace(kspace: "Kspace") -> go.Figure: """Creates a plot from the Brillouin zone sample points. Parameters ---------- kspace : Kspace Kspace class that contains the Brillouin-zone samples and weights Returns ------- plotly.graph_objs.go.Figure The created figure """ # Create the scatter plot # Create 3D scatter plot trace = go.Scatter3d( name=f"Kpoints", x=kspace.kpoints[:, 0], y=kspace.kpoints[:, 1], z=kspace.kpoints[:, 2], mode="markers", marker=dict( size=5, color=kspace.weights, colorscale="Viridis", opacity=1, colorbar=dict(title="Weights of kpoints", x=0.75), ), ) # Update the layout layout = go.Layout( autosize=False, title="Brillouin zone sampling", width=800, height=500, scene=dict( aspectmode="data", xaxis=dict(title="X Axis", showgrid=True, gridwidth=1), yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1), zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1), ), ) # Create figure and show fig = go.Figure(data=[trace], layout=layout) return fig
[docs] def plot_magnetic_entities( magnetic_entities: Union[Builder, list["MagneticEntity"]], ) -> go.Figure: """Creates a plot from a list of magnetic entities. Parameters ---------- magnetic_entities : Union[Builder, list["MagneticEntity"]] The magnetic entities that contain the tags and coordinates Returns ------- plotly.graph_objs.go.Figure The created figure """ # conversion line for the case when it is set as the plot function of a builder if isinstance(magnetic_entities, Builder): magnetic_entities = magnetic_entities.magnetic_entities elif not isinstance(magnetic_entities, list): magnetic_entities = [magnetic_entities] tags = [m.tag for m in magnetic_entities] coords = [m._xyz for m in magnetic_entities] colors = ["red", "green", "blue", "purple", "orange", "cyan", "magenta"] colors = colors * (len(coords) // len(colors) + 1) # Create figure fig = go.Figure() for coord, color, tag in zip(coords, colors, tags): fig.add_trace( go.Scatter3d( name=tag, x=coord[:, 0], y=coord[:, 1], z=coord[:, 2], mode="markers", marker=dict(size=10, opacity=0.8, color=color), ) ) # Create layout fig.update_layout( autosize=False, width=800, height=500, scene=dict( aspectmode="data", xaxis=dict(title="X Axis", showgrid=True, gridwidth=1), yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1), zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1), ), ) return fig
[docs] def plot_pairs(pairs: Union[Builder, list["Pair"]], connect: bool = False) -> go.Figure: """Creates a plot from a list of pairs. Parameters ---------- pairs : Union[Builder, list[Pair]] The pairs that contain the tags and coordinates connect : bool, optional Wether to connect the pairs or not, by default False Returns ------- plotly.graph_objs.go.Figure The created figure """ # conversion line for the case when it is set as the plot function of a builder if isinstance(pairs, Builder): pairs = pairs.pairs elif not isinstance(pairs, list): pairs = [pairs] # the centers can contain many atoms centers = [p.xyz[0] for p in pairs] # find unique centers uniques = [] def in_unique(c): for u in uniques: if c.shape == u.shape: if np.all(c == u): return True return False for c in centers: if not in_unique(c): uniques.append(c) # findex indexes for the same center idx = [[] for u in uniques] for i, u in enumerate(uniques): for j, c in enumerate(centers): if c.shape == u.shape: if np.all(c == u): idx[i].append(j) center_tags = np.array([p.tags[0] for p in pairs]) interacting_atoms = np.array([p.xyz[1] for p in pairs], dtype=object) interacting_tags = np.array( [p.tags[1] + ", ruc:" + str(p.supercell_shift) for p in pairs] ) colors = ["red", "green", "blue", "purple", "orange", "cyan", "magenta"] colors = colors * (len(centers) // len(colors) + 1) # Create figure fig = go.Figure() for i in range(len(idx)): center = centers[idx[i][0]] center_tag = center_tags[idx[i][0]] color = colors[i] # Create 3D scatter plot fig.add_trace( go.Scatter3d( name="Center:" + center_tag, x=center[:, 0], y=center[:, 1], z=center[:, 2], mode="markers", marker=dict(size=10, opacity=0.8, color=color), ) ) for interacting_atom, interacting_tag in zip( interacting_atoms[idx[i]], interacting_tags[idx[i]] ): legend_group = f"pair {center_tag}-{interacting_atom}" fig.add_trace( go.Scatter3d( name=interacting_tag, x=interacting_atom[:, 0], y=interacting_atom[:, 1], z=interacting_atom[:, 2], legendgroup=legend_group, mode="markers", marker=dict(size=5, opacity=0.5, color=color), ) ) if connect: fig.add_trace( go.Scatter3d( x=[center.mean(axis=0)[0], interacting_atom.mean(axis=0)[0]], y=[center.mean(axis=0)[1], interacting_atom.mean(axis=0)[1]], z=[center.mean(axis=0)[2], interacting_atom.mean(axis=0)[2]], mode="lines", legendgroup=legend_group, showlegend=False, line=dict(color=color), ) ) # Create layout fig.update_layout( autosize=False, width=800, height=500, scene=dict( aspectmode="data", xaxis=dict(title="X Axis", showgrid=True, gridwidth=1), yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1), zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1), ), ) return fig
[docs] def plot_DMI(pairs: Union[Builder, list["Pair"]], rescale: float = 1) -> go.Figure: """Creates a plot of the DM vectors from a list of pairs. It can only use pairs from a finished simulation. The magnitude of the vectors are in meV. Parameters ---------- pairs : Union[Builder, list["Pair"]] The pairs that contain the tags, coordinates and the DM vectors rescale : float, optional The length of the vectors are rescaled by this, by default 1 Returns ------- plotly.graph_objs.go.Figure The created figure """ # conversion line for the case when it is set as the plot function of a builder if isinstance(pairs, Builder): pairs = pairs.pairs elif not isinstance(pairs, list): pairs = [pairs] # Define some example vectors vectors = np.array([p.D_meV * rescale for p in pairs]) # Define origins (optional) origins = np.array( [(p.M1.xyz_center + p.M2.xyz_center + p.supercell_shift_xyz) / 2 for p in pairs] ) n_vectors = len(vectors) labels = ["-->".join(p.tags) + ", ruc:" + str(p.supercell_shift) for p in pairs] colors = ["red", "green", "blue", "purple", "orange", "cyan", "magenta"] colors = colors * (n_vectors // len(colors) + 1) # Create figure fig = go.Figure() # Maximum vector magnitude for scaling max_magnitude = max(np.linalg.norm(v) for v in vectors) # Add each vector as a cone for i, (vector, origin, label, color) in enumerate( zip(vectors, origins, labels, colors) ): # End point of the vector end = origin + vector legend_group = f"vector_{i}" # Add a line for the vector fig.add_trace( go.Scatter3d( x=[origin[0], end[0]], y=[origin[1], end[1]], z=[origin[2], end[2]], mode="lines", line=dict(color=color, width=5), name=label, legendgroup=legend_group, showlegend=True, ) ) # Add a cone at the end to represent the arrow head u, v, w = vector fig.add_trace( go.Cone( x=[end[0]], y=[end[1]], z=[end[2]], u=[u / 5], # Scale down for better visualization v=[v / 5], w=[w / 5], colorscale=[[0, color], [1, color]], showscale=False, sizemode="absolute", sizeref=max_magnitude / 10, legendgroup=legend_group, showlegend=False, ) ) # Set layout properties # Create layout fig.update_layout( autosize=False, width=800, height=500, scene=dict( aspectmode="data", xaxis=dict(title="X Axis", showgrid=True, gridwidth=1), yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1), zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1), ), ) return fig
if __name__ == "__main__": pass