# Copyright (c) [2024-2025] []
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
Visualization functions
=======================

.. currentmodule:: grogupy.viz.viz

.. autosummary::
   :toctree: _generated/

   plot_contour
   plot_kspace
   plot_magnetic_entities
   plot_pairs
   plot_DM

"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ..physics.contour import Contour
from ..physics.kspace import Kspace
from ..physics.magnetic_entity import MagneticEntity
from ..physics.pair import Pair
import numpy as np
import plotly.graph_objs as go
[docs]
def plot_contour(contour: "Contour") -> go.Figure:
    """Plot the sample points of an energy contour on the complex plane.

    Parameters
    ----------
    contour : Contour
        Contour object whose ``samples`` attribute holds the complex
        sample points; their real parts go on the x axis and their
        imaginary parts on the y axis

    Returns
    -------
    plotly.graph_objs.Figure
        The created figure
    """
    energies = contour.samples

    # Both axes use the same grid settings
    grid_style = dict(showgrid=True, gridwidth=1)

    points = go.Scatter(x=energies.real, y=energies.imag, mode="markers")
    fig = go.Figure(data=points)

    fig.update_layout(
        autosize=False,
        width=1200,
        height=700,
        title="Energy contour integral",
        xaxis_title="Real axis [eV]",
        yaxis_title="Imaginary axis [eV]",
        xaxis=dict(grid_style),
        yaxis=dict(grid_style),
    )
    # Anchor the y scale to the x scale (1:1) so the contour is not distorted
    fig.update_yaxes(scaleanchor="x", scaleratio=1)

    return fig
[docs]
def plot_kspace(kspace: "Kspace") -> go.Figure:
    """Plot the Brillouin zone sample points as a 3D scatter.

    Parameters
    ----------
    kspace : Kspace
        Kspace object; columns 0, 1 and 2 of its ``kpoints`` array are used
        as the x, y and z coordinates and its ``weights`` set the marker
        colors

    Returns
    -------
    plotly.graph_objs.Figure
        The created figure
    """
    kpts = kspace.kpoints

    # Markers are colored by the k-point weights, with a colorbar as legend
    marker_style = dict(
        size=5,
        color=kspace.weights,
        colorscale="Viridis",
        opacity=1,
        colorbar=dict(title="Weights of kpoints", x=0.75),
    )
    scatter = go.Scatter3d(
        name="Kpoints",
        x=kpts[:, 0],
        y=kpts[:, 1],
        z=kpts[:, 2],
        mode="markers",
        marker=marker_style,
    )

    # The three scene axes differ only in their title
    scene = dict(aspectmode="data")
    for axis_key, axis_title in (
        ("xaxis", "X Axis"),
        ("yaxis", "Y Axis"),
        ("zaxis", "Z Axis"),
    ):
        scene[axis_key] = dict(title=axis_title, showgrid=True, gridwidth=1)

    layout = go.Layout(autosize=False, width=1200, height=700, scene=scene)

    return go.Figure(data=[scatter], layout=layout)
[docs]
def plot_magnetic_entities(
    magnetic_entities: list["MagneticEntity"],
) -> go.Figure:
    """Plot every magnetic entity as its own 3D scatter trace.

    Parameters
    ----------
    magnetic_entities : list[MagneticEntity]
        The magnetic entities; each one contributes its ``tag`` as the
        trace name and the columns of its ``xyz`` array as coordinates

    Returns
    -------
    plotly.graph_objs.Figure
        The created figure
    """
    palette = ("red", "green", "blue", "purple", "orange", "cyan", "magenta")

    fig = go.Figure()
    for position, entity in enumerate(magnetic_entities):
        xyz = entity.xyz
        # Colors repeat cyclically once the palette is exhausted
        marker_color = palette[position % len(palette)]
        trace = go.Scatter3d(
            name=entity.tag,
            x=xyz[:, 0],
            y=xyz[:, 1],
            z=xyz[:, 2],
            mode="markers",
            marker=dict(size=10, opacity=0.8, color=marker_color),
        )
        fig.add_trace(trace)

    scene = dict(
        aspectmode="data",
        xaxis=dict(title="X Axis", showgrid=True, gridwidth=1),
        yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1),
        zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1),
    )
    fig.update_layout(autosize=False, width=1200, height=700, scene=scene)

    return fig
[docs]
def plot_pairs(pairs: list["Pair"], connect: bool = False) -> go.Figure:
    """Creates a plot from a list of pairs.

    Pairs are grouped by the position of their first magnetic entity
    (``M1.xyz_center``). Every group is drawn in one color: a large marker
    for the center and a small marker for each interacting partner, which
    is placed at ``M2.xyz_center + supercell_shift_xyz``.

    Parameters
    ----------
    pairs : list[Pair]
        The pairs that contain the tags and coordinates
    connect : bool, optional
        Whether to draw a line from each center to its interacting
        partners, by default False

    Returns
    -------
    plotly.graph_objs.Figure
        The created figure
    """
    palette = ["red", "green", "blue", "purple", "orange", "cyan", "magenta"]

    centers = np.array([p.M1.xyz_center for p in pairs])
    center_tags = np.array([p.tags[0] for p in pairs])
    interacting_atoms = np.array(
        [p.M2.xyz_center + p.supercell_shift_xyz for p in pairs]
    )
    interacting_tags = np.array(
        [p.tags[1] + ", ruc:" + str(p.supercell_shift) for p in pairs]
    )

    # Group the pairs by center. np.unique compares exactly while the
    # membership test below is tolerant, so two centers that differ only by
    # floating point noise would each collect the pairs of both. The
    # ``assigned`` mask puts every pair in exactly one group, which avoids
    # drawing the same traces twice.
    unique_centers = np.unique(centers, axis=0) if len(centers) else centers
    assigned = np.zeros(len(centers), dtype=bool)
    groups = []
    for unique in unique_centers:
        members = np.isclose(centers, unique).all(axis=1) & ~assigned
        if not members.any():
            # Everything close to this center already belongs to a group
            continue
        assigned |= members
        groups.append(np.where(members)[0])

    # Create figure
    fig = go.Figure()
    for i, group in enumerate(groups):
        center = centers[group[0]]
        center_tag = center_tags[group[0]]
        # Colors repeat cyclically once the palette is exhausted
        color = palette[i % len(palette)]

        # Large marker for the center of the group
        fig.add_trace(
            go.Scatter3d(
                name="Center:" + center_tag,
                x=[center[0]],
                y=[center[1]],
                z=[center[2]],
                mode="markers",
                marker=dict(size=10, opacity=0.8, color=color),
            )
        )

        for interacting_atom, interacting_tag in zip(
            interacting_atoms[group], interacting_tags[group]
        ):
            # The marker and its optional connecting line share a legend
            # group so that they are toggled together
            legend_group = f"pair {center_tag}-{interacting_atom}"
            fig.add_trace(
                go.Scatter3d(
                    name=interacting_tag,
                    x=[interacting_atom[0]],
                    y=[interacting_atom[1]],
                    z=[interacting_atom[2]],
                    legendgroup=legend_group,
                    mode="markers",
                    marker=dict(size=5, opacity=0.5, color=color),
                )
            )
            if connect:
                fig.add_trace(
                    go.Scatter3d(
                        x=[center[0], interacting_atom[0]],
                        y=[center[1], interacting_atom[1]],
                        z=[center[2], interacting_atom[2]],
                        mode="lines",
                        legendgroup=legend_group,
                        showlegend=False,
                        line=dict(color=color),
                    )
                )

    # Create layout
    fig.update_layout(
        autosize=False,
        width=1200,
        height=700,
        scene=dict(
            aspectmode="data",
            xaxis=dict(title="X Axis", showgrid=True, gridwidth=1),
            yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1),
            zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1),
        ),
    )

    return fig
[docs]
def plot_DM(pairs: list["Pair"], rescale: float = 1) -> go.Figure:
    """Creates a plot of the DM vectors from a list of pairs.

    It can only use pairs from a finished simulation, because it reads the
    ``D_meV`` attribute of every pair. The magnitude of the vectors are in
    meV. Each vector is drawn as a line with a cone as arrow head and
    starts at the midpoint between ``M1.xyz_center`` and
    ``M2.xyz_center + supercell_shift_xyz``. An empty list of pairs gives
    an empty figure.

    Parameters
    ----------
    pairs : list[Pair]
        The pairs that contain the tags, coordinates and the DM vectors
    rescale : float, optional
        The length of the vectors are rescaled by this, by default 1

    Returns
    -------
    plotly.graph_objs.Figure
        The created figure
    """
    palette = ["red", "green", "blue", "purple", "orange", "cyan", "magenta"]

    # Rescaled DM vectors, one per pair
    vectors = np.array([p.D_meV * rescale for p in pairs])
    # Every vector starts at the midpoint of its pair
    origins = np.array(
        [(p.M1.xyz_center + p.M2.xyz_center + p.supercell_shift_xyz) / 2 for p in pairs]
    )
    labels = ["-->".join(p.tags) + ", ruc:" + str(p.supercell_shift) for p in pairs]

    # Largest vector length, used as the reference size of the arrow heads.
    # The default keeps an empty pair list from raising ValueError.
    max_magnitude = max((np.linalg.norm(v) for v in vectors), default=0.0)

    # Create figure
    fig = go.Figure()
    for i, (vector, origin, label) in enumerate(zip(vectors, origins, labels)):
        # Colors repeat cyclically once the palette is exhausted
        color = palette[i % len(palette)]
        # End point of the vector
        end = origin + vector
        # The line and its arrow head are toggled together in the legend
        legend_group = f"vector_{i}"

        # Add a line for the vector
        fig.add_trace(
            go.Scatter3d(
                x=[origin[0], end[0]],
                y=[origin[1], end[1]],
                z=[origin[2], end[2]],
                mode="lines",
                line=dict(color=color, width=5),
                name=label,
                legendgroup=legend_group,
                showlegend=True,
            )
        )

        # Add a cone at the end to represent the arrow head
        u, v, w = vector
        fig.add_trace(
            go.Cone(
                x=[end[0]],
                y=[end[1]],
                z=[end[2]],
                u=[u / 5],  # Scale down for better visualization
                v=[v / 5],
                w=[w / 5],
                # A two-stop scale of one color gives a uniformly colored cone
                colorscale=[[0, color], [1, color]],
                showscale=False,
                sizemode="absolute",
                sizeref=max_magnitude / 10,
                legendgroup=legend_group,
                showlegend=False,
            )
        )

    # Create layout
    fig.update_layout(
        autosize=False,
        width=1200,
        height=700,
        scene=dict(
            aspectmode="data",
            xaxis=dict(title="X Axis", showgrid=True, gridwidth=1),
            yaxis=dict(title="Y Axis", showgrid=True, gridwidth=1),
            zaxis=dict(title="Z Axis", showgrid=True, gridwidth=1),
        ),
    )

    return fig
# Running this module as a script intentionally does nothing
if __name__ == "__main__":
    pass