# Plotting helpers for the electronic density of states (DOS), built on matplotlib.
__author__ = "Pedram Tavadze and Logan Lang"
__maintainer__ = "Pedram Tavadze and Logan Lang"
__email__ = "petavazohi@mail.wvu.edu, lllang@mix.wvu.edu"
__date__ = "March 31, 2020"
import os
import yaml
import json
from typing import List
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pylab as plt
from matplotlib.collections import LineCollection
from matplotlib.ticker import MultipleLocator, FormatStrFormatter, AutoMinorLocator
from pyprocar.utils import ROOT,ConfigManager
from ..utils.defaults import settings
from ..core import Structure, DensityOfStates
# NOTE: this changes NumPy's error handling as an import side effect.
# Divide-by-zero and invalid-value warnings are silenced because this module divides
# the projected DOS by the total projected DOS, whose entries may be zero; the
# resulting NaNs are cleaned up afterwards with np.nan_to_num.
np.seterr(divide="ignore", invalid="ignore")
[docs]
class DOSPlot:
"""
Class to plot the electronic density of states (DOS).
Parameters
----------
dos : DensityOfStates
A density of states object, pyprocar.core.DensityOfStates.
structure : Structure
A structure object, pyprocar.core.Structure.
ax : mpl.axes.Axes, optional
A matplotlib Axes object. If provided the plot will be located at that ax.
The default is None.
Returns
-------
None.
"""
[docs]
def __init__(self,
             dos:DensityOfStates=None,
             structure:Structure=None,
             ax:mpl.axes.Axes=None,
             orientation:str='horizontal',
             config=None):
    """Set up the plotter state and the matplotlib figure/axes to draw on.

    Parameters
    ----------
    dos : DensityOfStates, optional
        The density of states to plot.
    structure : Structure, optional
        The structure the density of states belongs to.
    ax : mpl.axes.Axes, optional
        Axes to draw on. When None, a new figure sized by
        ``config.figure_size`` is created with a single subplot.
    orientation : str, optional
        Either ``'horizontal'`` or ``'vertical'``, by default ``'horizontal'``.
    config : optional
        Configuration object read through attribute access
        (e.g. ``config.figure_size``).

    Raises
    ------
    ValueError
        If ``orientation`` is neither ``'horizontal'`` nor ``'vertical'``.
    """
    # Validate first so that a bad argument does not leave an orphan figure behind.
    if orientation not in ('horizontal', 'vertical'):
        raise ValueError(f"The orientation must be either horizontal or vertical, not {orientation}")

    self.config = config
    self.dos = dos
    self.structure = structure
    self.handles = []
    self.labels = []
    self.orientation = orientation
    self.values_dict = {}

    if ax is None:
        self.fig = plt.figure(figsize=tuple(self.config.figure_size),)
        self.ax = self.fig.add_subplot(111)
    else:
        # Use the figure that owns the supplied axes; pyplot's "current figure"
        # can be a different one (or a freshly created empty one).
        self.fig = ax.get_figure()
        self.ax = ax
    return None
[docs]
def plot_dos(self, spins: List[int] = None):
    """Plot the total density of states for the selected spin channels.

    Parameters
    ----------
    spins : List[int], optional
        Spins to plot. Forwarded to ``_get_spins_projections_and_channels``,
        which supplies defaults when this is None or empty.

    Returns
    -------
    dict
        The plotted data: ``'energies'`` plus one ``'dosTotalSpin-<channel>'``
        entry per spin channel. The same dict is stored on ``self.values_dict``.
    """
    collected = {}
    _, channels = self._get_spins_projections_and_channels(spins)
    energy_grid = self.dos.energies
    total = self.dos.total
    self._set_plot_limits(channels)
    for position, channel in enumerate(channels):
        # Every channel after the first one is mirrored below the axis.
        sign = -1 if position > 0 else 1
        channel_dos = total[channel, :] * sign
        self._plot_total_dos(energy_grid, channel_dos, channel)
        collected['energies'] = energy_grid
        collected[f'dosTotalSpin-{channel}'] = channel_dos
    self.values_dict = collected
    return collected
[docs]
def plot_parametric(self,
                    atoms: List[int] = None,
                    orbitals: List[int] = None,
                    spins: List[int] = None,
                    principal_q_numbers: List[int] = [-1]):
    """Plot the total DOS filled with a colour encoding the projected weight.

    Parameters
    ----------
    atoms : List[int], optional
        Atoms to project onto, by default None.
    orbitals : List[int], optional
        Orbitals to project onto, by default None.
    spins : List[int], optional
        Spins to consider, by default None.
    principal_q_numbers : List[int], optional
        Principal quantum numbers, by default [-1]. The default list is never
        mutated here.

    Returns
    -------
    dict
        ``'energies'``, one ``'dosTotalSpin-<channel>'`` entry and one
        ``'spinChannel-<channel>_orbitals-..._atoms-..._spinProjection-...'``
        entry per spin channel. Also stored on ``self.values_dict``.
    """
    values_dict = {}
    spin_projections, spin_channels = self._get_spins_projections_and_channels(spins)
    dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
        atoms, orbitals, spin_projections, principal_q_numbers)

    # atoms/orbitals default to None, and iterating None raised TypeError here.
    # None is labelled 'all' -- presumably dos_sum treats None as "everything",
    # as the argument-less dos_sum() call for the total suggests; verify against dos_sum.
    orbital_string = 'all' if orbitals is None else ':'.join(str(orbital) for orbital in orbitals)
    atom_string = 'all' if atoms is None else ':'.join(str(atom) for atom in atoms)
    spin_string = ':'.join(str(spin_projection) for spin_projection in spin_projections)
    projection_suffix = (f'_orbitals-{orbital_string}'
                         f'_atoms-{atom_string}'
                         f'_spinProjection-{spin_string}')

    self._setup_colorbar(dos_projected, dos_total_projected)
    self._set_plot_limits(spin_channels)
    for ispin, spin_channel in enumerate(spin_channels):
        energies, dos_spin_total, normalized_dos_spin_projected = self._prepare_parametric_spin_data(
            spin_channel, ispin, dos_total, dos_projected, dos_total_projected)
        self._plot_spin_data_parametric(energies, dos_spin_total, normalized_dos_spin_projected)
        if self.config.plot_total:
            self._plot_total_dos(energies, dos_spin_total, spin_channel)
        values_dict['energies'] = energies
        values_dict[f'dosTotalSpin-{spin_channel}'] = dos_spin_total
        values_dict[f'spinChannel-{spin_channel}' + projection_suffix] = normalized_dos_spin_projected
    self.values_dict = values_dict
    return values_dict
[docs]
def plot_parametric_line(self,
                         atoms: List[int] = None,
                         orbitals: List[int] = None,
                         spins: List[int] = None,
                         principal_q_numbers: List[int] = [-1]):
    """Plot the total DOS as a line coloured by the projected weight.

    Parameters
    ----------
    atoms : List[int], optional
        Atoms to project onto, by default None.
    orbitals : List[int], optional
        Orbitals to project onto, by default None.
    spins : List[int], optional
        Spins to consider, by default None.
    principal_q_numbers : List[int], optional
        Principal quantum numbers, by default [-1]. The default list is never
        mutated here.

    Returns
    -------
    dict
        ``'energies'``, one ``'dosTotalSpin-<channel>'`` entry and one
        ``'spinChannel-<channel>_orbitals-..._atoms-..._spinProjection-...'``
        entry per spin channel. Also stored on ``self.values_dict``.
    """
    values_dict = {}
    spin_projections, spin_channels = self._get_spins_projections_and_channels(spins)
    dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
        atoms, orbitals, spin_projections, principal_q_numbers)

    # atoms/orbitals default to None, and iterating None raised TypeError here.
    # None is labelled 'all' -- presumably dos_sum treats None as "everything",
    # as the argument-less dos_sum() call for the total suggests; verify against dos_sum.
    orbital_string = 'all' if orbitals is None else ':'.join(str(orbital) for orbital in orbitals)
    atom_string = 'all' if atoms is None else ':'.join(str(atom) for atom in atoms)
    spin_string = ':'.join(str(spin_projection) for spin_projection in spin_projections)
    projection_suffix = (f'_orbitals-{orbital_string}'
                         f'_atoms-{atom_string}'
                         f'_spinProjection-{spin_string}')

    # _setup_colorbar also stores self.clim, which the line plot reads for its colour scale.
    self._setup_colorbar(dos_projected, dos_total_projected)
    self._set_plot_limits(spin_channels)
    for ispin, spin_channel in enumerate(spin_channels):
        energies, dos_spin_total, normalized_dos_spin_projected = self._prepare_parametric_spin_data(
            spin_channel, ispin, dos_total, dos_projected, dos_total_projected)
        self._plot_spin_data_parametric_line(
            energies, dos_spin_total, normalized_dos_spin_projected, spin_channel)
        values_dict['energies'] = energies
        values_dict[f'dosTotalSpin-{spin_channel}'] = dos_spin_total
        values_dict[f'spinChannel-{spin_channel}' + projection_suffix] = normalized_dos_spin_projected
    self.values_dict = values_dict
    return values_dict
[docs]
def plot_stack_species(
        self,
        principal_q_numbers:List[int]=[-1],
        orbitals:List[int]=None,
        spins:List[int]=None,
        overlay_mode:bool=False,
    ):
    """Plot the projected DOS of every species, stacked or overlaid.

    Parameters
    ----------
    principal_q_numbers : List[int], optional
        Principal quantum numbers, by default [-1] (never mutated here).
    orbitals : List[int], optional
        Orbitals to project onto, by default None.
    spins : List[int], optional
        Spins to consider, by default None.
    overlay_mode : bool, optional
        When True each contribution is drawn as a line instead of being
        stacked as a filled area, by default False.

    Returns
    -------
    dict
        ``'energies'``, ``'dosTotalSpin-<channel>'`` and one
        ``'spinChannel-<channel>_orbitals-..._atoms-..._spinProjection-...'``
        entry per species and spin channel. Also stored on ``self.values_dict``.
    """
    values_dict = {}
    spin_projections, spin_channels = self._get_spins_projections_and_channels(spins)
    orbital_label = self._get_stack_species_labels(orbitals)
    self._set_plot_limits(spin_channels)

    # orbitals defaults to None; iterating it raised TypeError. 'all' is an
    # assumption about how dos_sum treats None -- verify against dos_sum.
    orbital_string = 'all' if orbitals is None else ':'.join(str(orbital) for orbital in orbitals)
    spin_string = ':'.join(str(spin_projection) for spin_projection in spin_projections)

    # One running baseline per spin channel: the mirrored (negative) channel must
    # stack downwards from its own baseline, not from the top of the other channel.
    baselines = {}
    for ispecie, specie_name in enumerate(self.structure.species):
        atoms = list(np.where(np.array(self.structure.atoms) == specie_name)[0])
        atom_string = ':'.join(str(atom) for atom in atoms)
        dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
            atoms,
            orbitals,
            spin_projections,
            principal_q_numbers)
        color = self.config.colors[ispecie]
        label = specie_name + orbital_label
        for ispin, spin_channel in enumerate(spin_channels):
            energies, dos_spin_total, scaled_dos_spin_projected = self._prepare_parametric_spin_data(
                spin_channel,
                ispin,
                dos_total,
                dos_projected,
                dos_total_projected,
                scale=True)
            if overlay_mode:
                handle = self._plot_spin_overlay(
                    energies,
                    scaled_dos_spin_projected,
                    spin_channel,
                    color)
            else:
                bottom_value = baselines.get(ispin, 0)
                top_value, handle = self._plot_spin_stack(
                    energies,
                    scaled_dos_spin_projected,
                    bottom_value,
                    color)
                baselines[ispin] = bottom_value + top_value
            values_dict['energies'] = energies
            values_dict[f'dosTotalSpin-{spin_channel}'] = dos_spin_total
            values_dict[f'spinChannel-{spin_channel}'
                        f'_orbitals-{orbital_string}'
                        f'_atoms-{atom_string}'
                        f'_spinProjection-{spin_string}'] = scaled_dos_spin_projected
            self.handles.append(handle)
            self.labels.append(label)
    if self.config.plot_total:
        # plot_dos overwrites self.values_dict; it is reset to the stack data below.
        self.plot_dos(spin_channels)
    self.values_dict = values_dict
    return values_dict
[docs]
def plot_stack_orbitals(
        self,
        principal_q_numbers:List[int]=[-1],
        atoms:List[int]=None,
        spins:List[int]=None,
        overlay_mode:bool=False,
    ):
    """Plot the projected DOS of every orbital group, stacked or overlaid.

    Parameters
    ----------
    principal_q_numbers : List[int], optional
        Principal quantum numbers, by default [-1] (never mutated here).
    atoms : List[int], optional
        Atoms to project onto, by default None.
    spins : List[int], optional
        Spins to consider, by default None.
    overlay_mode : bool, optional
        When True each contribution is drawn as a line instead of being
        stacked as a filled area, by default False.

    Returns
    -------
    dict
        ``'energies'``, ``'dosTotalSpin-<channel>'`` and one
        ``'spinChannel-<channel>_orbitals-..._atoms-..._spinProjection-...'``
        entry per orbital group and spin channel. Also stored on
        ``self.values_dict``.
    """
    values_dict = {}
    spin_projections, spin_channels = self._get_spins_projections_and_channels(spins)
    atom_names, orb_names, orb_l = self._get_stack_orbitals_labels(atoms)
    self._set_plot_limits(spin_channels)

    # atoms defaults to None; iterating it raised TypeError. 'all' is an
    # assumption about how dos_sum treats None -- verify against dos_sum.
    atom_string = 'all' if atoms is None else ':'.join(str(atom) for atom in atoms)
    spin_string = ':'.join(str(spin_projection) for spin_projection in spin_projections)

    # One running baseline per spin channel: the mirrored (negative) channel must
    # stack downwards from its own baseline, not from the top of the other channel.
    baselines = {}
    for iorb, orbital_group in enumerate(orb_l):
        orbital_string = ':'.join(str(orbital) for orbital in orbital_group)
        dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
            atoms=atoms,
            orbitals=orbital_group,
            spin_projections=spin_projections,
            principal_q_numbers=principal_q_numbers)
        color = self.config.colors[iorb]
        label = atom_names + orb_names[iorb]
        for ispin, spin_channel in enumerate(spin_channels):
            energies, dos_spin_total, scaled_dos_spin_projected = self._prepare_parametric_spin_data(
                spin_channel,
                ispin,
                dos_total,
                dos_projected,
                dos_total_projected,
                scale=True)
            if overlay_mode:
                handle = self._plot_spin_overlay(
                    energies,
                    scaled_dos_spin_projected,
                    spin_channel,
                    color)
            else:
                bottom_value = baselines.get(ispin, 0)
                top_value, handle = self._plot_spin_stack(
                    energies,
                    scaled_dos_spin_projected,
                    bottom_value,
                    color)
                baselines[ispin] = bottom_value + top_value
            values_dict['energies'] = energies
            values_dict[f'dosTotalSpin-{spin_channel}'] = dos_spin_total
            values_dict[f'spinChannel-{spin_channel}'
                        f'_orbitals-{orbital_string}'
                        f'_atoms-{atom_string}'
                        f'_spinProjection-{spin_string}'] = scaled_dos_spin_projected
            self.handles.append(handle)
            self.labels.append(label)
    if self.config.plot_total:
        # plot_dos overwrites self.values_dict; it is reset to the stack data below.
        self.plot_dos(spin_channels)
    self.values_dict = values_dict
    return values_dict
[docs]
def plot_stack(
        self,
        items:dict=None,
        principal_q_numbers:List[int]=[-1],
        spins:List[int]=None,
        overlay_mode:bool=False,
    ):
    """Plot user-selected species/orbital contributions, stacked or overlaid.

    Parameters
    ----------
    items : dict
        Maps a species name to the list of orbital indices to plot for it,
        e.g. ``{'Sr': [1, 2, 3], 'O': [4, 5, 6, 7, 8]}``.
    principal_q_numbers : List[int], optional
        Principal quantum numbers, by default [-1] (never mutated here).
    spins : List[int], optional
        Spins to consider, by default None.
    overlay_mode : bool, optional
        When True each contribution is drawn as a line instead of being
        stacked as a filled area, by default False.

    Returns
    -------
    dict
        ``'energies'``, ``'dosTotalSpin-<channel>'`` and one
        ``'spinChannel-<channel>_orbitals-..._atoms-..._spinProjection-...'``
        entry per item and spin channel. Also stored on ``self.values_dict``.

    Raises
    ------
    ValueError
        If ``items`` is None.
    """
    # The previous guard (`len(items) is None`) could never be true and crashed
    # inside len() when items was None; fail early with the intended guidance.
    if items is None:
        raise ValueError(
            "Please provide the stacking items you want to plot, "
            "example : {'Sr':[1,2,3],'O':[4,5,6,7,8]} will plot the stacked plots "
            "of p orbitals of Sr and d orbitals of Oxygen.")

    values_dict = {}
    spin_projections, spin_channels = self._get_spins_projections_and_channels(spins)
    self._set_plot_limits(spin_channels)

    # One colour per species, taken from the configured colour list in item order.
    colors_dict = {specie: self.config.colors[index] for index, specie in enumerate(items)}
    spin_string = ':'.join(str(spin_projection) for spin_projection in spin_projections)

    # One running baseline per spin channel: the mirrored (negative) channel must
    # stack downwards from its own baseline, not from the top of the other channel.
    baselines = {}
    for specie, orbitals in items.items():
        atoms = list(np.where(np.array(self.structure.atoms) == specie)[0])
        orbital_label = self._get_stack_labels(orbitals)
        orbital_string = ':'.join(str(orbital) for orbital in orbitals)
        atom_string = ':'.join(str(atom) for atom in atoms)
        dos_total, dos_total_projected, dos_projected = self._calculate_parametric_dos(
            atoms=atoms,
            orbitals=orbitals,
            spin_projections=spin_projections,
            principal_q_numbers=principal_q_numbers)
        color = colors_dict[specie]
        label = specie + orbital_label
        for ispin, spin_channel in enumerate(spin_channels):
            energies, dos_spin_total, scaled_dos_spin_projected = self._prepare_parametric_spin_data(
                spin_channel,
                ispin,
                dos_total,
                dos_projected,
                dos_total_projected,
                scale=True)
            if overlay_mode:
                handle = self._plot_spin_overlay(
                    energies,
                    scaled_dos_spin_projected,
                    spin_channel,
                    color)
            else:
                bottom_value = baselines.get(ispin, 0)
                top_value, handle = self._plot_spin_stack(
                    energies,
                    scaled_dos_spin_projected,
                    bottom_value,
                    color)
                baselines[ispin] = bottom_value + top_value
            values_dict['energies'] = energies
            values_dict[f'dosTotalSpin-{spin_channel}'] = dos_spin_total
            values_dict[f'spinChannel-{spin_channel}'
                        f'_orbitals-{orbital_string}'
                        f'_atoms-{atom_string}'
                        f'_spinProjection-{spin_string}'] = scaled_dos_spin_projected
            self.handles.append(handle)
            self.labels.append(label)
    if self.config.plot_total:
        # plot_dos overwrites self.values_dict; it is reset to the stack data below.
        self.plot_dos(spin_channels)
    self.values_dict = values_dict
    return values_dict
def _calculate_parametric_dos(self, atoms, orbitals, spin_projections, principal_q_numbers):
    """Collect the three DOS arrays used by the projected plots.

    Returns
    -------
    tuple
        ``(total, total_projected, projected)`` where ``total`` is an array copy
        of ``self.dos.total``, ``total_projected`` is ``self.dos.dos_sum()``
        called without arguments, and ``projected`` is ``self.dos.dos_sum``
        restricted to the given atoms, principal quantum numbers, orbitals
        and spins.
    """
    selection = dict(atoms=atoms,
                     principal_q_numbers=principal_q_numbers,
                     orbitals=orbitals,
                     spins=spin_projections)
    total = np.array(self.dos.total)
    total_projected = self.dos.dos_sum()
    projected = self.dos.dos_sum(**selection)
    return total, total_projected, projected
def _get_spins_projections_and_channels(self, spins):
    """Derive the spin projections and spin channels from the ``spins`` argument.

    Parameters
    ----------
    spins : list of int, optional
        The requested spins. None or an empty list selects the defaults.

    Returns
    -------
    spin_projections : list of int
        For a non-collinear DOS: ``spins``, or ``[0, 1, 2]`` by default.
        Otherwise: ``spins``, or ``range(self.dos.n_spins)`` by default.
    spin_channels : list of int
        ``[0]`` for a non-collinear DOS; otherwise the same selection as
        ``spin_projections``.
    """
    if not self.dos.is_non_collinear:
        every_channel = range(self.dos.n_spins)
        selection = spins or every_channel
        return selection, selection
    # Non-collinear data has a single channel; the projections are the spin components.
    return (spins or [0, 1, 2]), [0]
def _get_stack_species_labels(self, orbitals):
    """Build the orbital suffix appended to species names in the legend.

    Parameters
    ----------
    orbitals : list of int, optional
        The selected orbitals. When empty or None, the label describes the
        whole basis instead.

    Returns
    -------
    str
        ``"-"`` followed by the name of every orbital group fully contained
        in ``orbitals``; or, without a selection, a label for the full basis
        (``"-spd"``, ``"-spdf"``, ``"-spd-j=0.5,1.5,2.5"``) or a bare ``"-"``.
    """
    coupled_groups = (
        ("s-j=0.5", [0, 1]),
        ("p-j=0.5", [2, 3]),
        ("p-j=1.5", [4, 5, 6, 7]),
        ("d-j=1.5", [8, 9, 10, 11]),
        ("d-j=2.5", [12, 13, 14, 15, 16, 17]),
    )
    uncoupled_groups = (
        ("s", [0]),
        ("p", [1, 2, 3]),
        ("d", [4, 5, 6, 7, 8]),
        ("f", [9, 10, 11, 12, 13, 14, 15]),
    )
    # The coupled basis is recognised by its 2 + 2 + 4 + 4 + 6 projected orbitals.
    coupled = self.dos.is_non_collinear and len(self.dos.projected[0][0]) == 2 + 2 + 4 + 4 + 6

    if orbitals:
        print("The plot only considers orbitals", orbitals)
        groups = coupled_groups if coupled else uncoupled_groups
        # A group contributes only when every one of its members was selected.
        return "-" + "".join(name for name, members in groups
                             if all(member in orbitals for member in members))

    if coupled:
        return "-spd-j=0.5,1.5,2.5"
    full_basis_labels = {1 + 3 + 5: "-spd", 1 + 3 + 5 + 7: "-spdf"}
    return full_basis_labels.get(len(self.dos.projected[0][0]), "-")
def _get_stack_orbitals_labels(self,atoms):
    """Build the legend prefix and the orbital groups for an orbital-stacked plot.

    Parameters
    ----------
    atoms : list of int, optional
        Indices of the selected atoms. None or an empty list means no
        atom prefix is produced.

    Returns
    -------
    atom_names : str
        ``"<name>-"`` for every distinct selected atom name, or ``""`` when no
        selection was given or the selection covers every distinct name.
    orb_names : list of str
        Display name of each orbital group.
    orb_l : list of list of int
        Orbital indices belonging to each group.

    Raises
    ------
    ValueError
        If the number of projected orbitals matches none of the known bases
        (previously this surfaced as an UnboundLocalError).
    """
    atom_names = ""
    if atoms:
        selected = np.array(self.structure.atoms)[atoms]
        print(
            "The plot only considers atoms",
            selected,
        )
        atom_names = "".join(f"{name}-" for name in np.unique(selected))
        all_atoms = "".join(f"{name}-" for name in np.unique(np.array(self.structure.atoms)))
        # Selecting every species is the same as selecting none: drop the prefix.
        if atom_names == all_atoms:
            atom_names = ""

    n_orbitals = len(self.dos.projected[0][0])
    if self.dos.is_non_collinear and n_orbitals == 2 + 2 + 4 + 4 + 6:
        orb_names = ["s-j=0.5", "p-j=0.5", "p-j=1.5", "d-j=1.5", "d-j=2.5"]
        orb_l = [[0,1], [2,3], [4, 5, 6, 7], [8,9,10,11], [12,13,14,15,16,17]]
    elif n_orbitals == 1 + 3 + 5:
        orb_names = ["s", "p", "d"]
        orb_l = [[0], [1, 2, 3], [4, 5, 6, 7, 8]]
    elif n_orbitals == 1 + 3 + 5 + 7:
        orb_names = ["s", "p", "d", "f"]
        orb_l = [[0], [1, 2, 3], [4, 5, 6, 7, 8], [9, 10, 11, 12, 13, 14, 15]]
    else:
        raise ValueError(
            f"Unsupported orbital basis with {n_orbitals} projected orbitals; "
            "expected 9 (spd), 16 (spdf) or, for non-collinear data, 18 (coupled spd)")
    return atom_names, orb_names, orb_l
def _get_stack_labels(self, orbitals):
    """Build the orbital suffix appended to a species name in ``plot_stack``.

    Parameters
    ----------
    orbitals : list of int
        The orbitals selected for the species.

    Returns
    -------
    str
        ``"-"`` followed by the name of every orbital group fully contained
        in ``orbitals``, or ``""`` when that label equals ``"-"`` plus the
        full-basis name computed below.
    """
    coupled_groups = (
        ("s-j=0.5", [0, 1]),
        ("p-j=0.5", [2, 3]),
        ("p-j=1.5", [4, 5, 6, 7]),
        ("d-j=1.5", [8, 9, 10, 11]),
        ("d-j=2.5", [12, 13, 14, 15, 16, 17]),
    )
    uncoupled_groups = (
        ("s", [0]),
        ("p", [1, 2, 3]),
        ("d", [4, 5, 6, 7, 8]),
        ("f", [9, 10, 11, 12, 13, 14, 15]),
    )
    n_orbitals = len(self.dos.projected[0][0])
    coupled_size = n_orbitals == 2 + 2 + 4 + 4 + 6

    # Name of the complete basis; used only to blank the label in the final comparison.
    if self.dos.is_non_collinear and coupled_size:
        all_orbitals = "-spd-j=0.5,1.5,2.5"
    elif n_orbitals == (1 + 3 + 5):
        all_orbitals = "spd"
    elif n_orbitals == (1 + 3 + 5 + 7):
        all_orbitals = "spdf"
    else:
        all_orbitals = ""

    # The grouping depends on the basis size alone, not on the non-collinear flag.
    groups = coupled_groups if coupled_size else uncoupled_groups
    label = "-" + "".join(name for name, members in groups
                          if all(member in orbitals for member in members))
    return "" if label == "-" + all_orbitals else label
def _setup_colorbar(self, dos_projected, dos_total_projected):
    """Compute the colour limits and, if configured, draw the colour bar.

    The limits are always computed because ``_get_color_limits`` stores them on
    ``self.clim``. The colour map is only looked up when ``config.plot_bar``
    is set, since nothing else in this method needs it.

    Parameters
    ----------
    dos_projected : np.ndarray
        The projected density of states.
    dos_total_projected : np.ndarray
        The total projected density of states.
    """
    vmin, vmax = self._get_color_limits(dos_projected, dos_total_projected)
    if not self.config.plot_bar:
        return None

    # plt.get_cmap replaces mpl.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap(self.config.cmap)
    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
    cb = self.fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap), ax=self.ax)
    cb.ax.tick_params(labelsize=self.config.colorbar_tick_labelsize)
    cb.set_label(self.config.colorbar_title,
                 size=self.config.colorbar_title_size,
                 rotation=270,
                 labelpad=self.config.colorbar_title_padding)
    return None
def _get_color_limits(self, dos_projected, dos_total_projected):
    """Return the colour-scale limits and store them on ``self.clim``.

    ``config.clim`` is used when it is set. Otherwise the limits are the
    minimum and maximum of ``dos_projected``, each divided by the maximum of
    ``dos_total_projected``.
    """
    limits = self.config.clim
    if not limits:
        limits = [
            dos_projected.min() / dos_total_projected.max(),
            dos_projected.max() / dos_total_projected.max(),
        ]
    self.clim = limits
    return self.clim
def _set_plot_limits(self, spin_channels):
    """Set the axis labels and limits for the plotted spin channels.

    The DOS axis runs from 0 (or from ``-dos.total.max()`` when two spin
    channels are plotted, to make room for the mirrored channel) up to the
    largest total DOS among the selected channels. The energy axis spans the
    full energy range. In vertical orientation the two axes, and their labels,
    are swapped.

    Parameters
    ----------
    spin_channels : list of int
        The spin channels being plotted.
    """
    # Index dos.total by the channel itself, not by its position in the list:
    # plotting only channel 1 must look at total[1], not total[0].
    total_max = 0
    for spin_channel in spin_channels:
        total_max = max(total_max, self.dos.total[spin_channel].max())

    energy_limits = [self.dos.energies.min(), self.dos.energies.max()]
    if len(spin_channels) == 2:
        dos_limits = [-self.dos.total.max(), total_max]
    else:
        dos_limits = [0, total_max]

    # Labels are set on the axes directly. set_xlabel/set_ylabel give priority to
    # config.x_label/config.y_label, which undid the label swap in vertical mode.
    if self.orientation == 'horizontal':
        self.ax.set_xlabel(self.config.x_label)
        self.ax.set_ylabel(self.config.y_label)
        self.ax.set_xlim(energy_limits)
        self.ax.set_ylim(dos_limits)
    elif self.orientation == 'vertical':
        self.ax.set_xlabel(self.config.y_label)
        self.ax.set_ylabel(self.config.x_label)
        self.ax.set_xlim(dos_limits)
        self.ax.set_ylim(energy_limits)
def _prepare_parametric_spin_data(self, spin_channel, ispin,
                                  dos_total, dos_projected, dos_total_projected, scale=False):
    """
    Prepares the data of one spin channel for the projected plots.

    The input arrays are not modified.

    Parameters
    ----------
    spin_channel : int
        The spin channel being plotted; used as the row index into the arrays
    ispin : int
        The position of the spin channel among the plotted channels
    dos_total : np.ndarray
        The total density of states
    dos_projected : np.ndarray
        The projected density of states
    dos_total_projected : np.ndarray
        The projected total density of states
    scale : bool, optional
        When True the normalized projection is multiplied by the total DOS

    Returns
    -------
    energies : np.ndarray
        ``self.dos.energies``
    dos_total : np.ndarray
        The total DOS of the channel, negated when ``ispin > 0`` and the DOS
        has more than one spin row
    dos_projected : np.ndarray
        ``projected / total_projected`` with NaN replaced by 0; when ``scale``
        is set, that ratio times the (possibly negated) total DOS
    """
    energies = self.dos.energies
    spin_total = dos_total[spin_channel, :]
    spin_projected = dos_projected[spin_channel, :]
    spin_total_projected = dos_total_projected[spin_channel, :]

    # 0/0 gives NaN, which becomes 0. `nan` must be passed by keyword:
    # the second positional argument of np.nan_to_num is `copy`.
    normalized = np.nan_to_num(spin_projected / spin_total_projected, nan=0.0)

    if ispin > 0 and len(self.dos.total) > 1:
        # Build a new array instead of `*= -1`, which would flip the caller's data in place.
        spin_total = -spin_total

    if not scale:
        return energies, spin_total, normalized

    scaled = normalized * spin_total
    # Discard values that cannot be real weights (for example, huge numbers from x/0).
    threshold = max(abs(spin_total)) + 1
    scaled[np.abs(scaled) > threshold] = 0
    return energies, spin_total, scaled
def _get_bar_color(self, values):
    """Map each value to a colour of the configured colour map.

    Parameters
    ----------
    values : iterable of float
        Values passed one by one to the colour map named by ``config.cmap``.

    Returns
    -------
    list
        One colour per value, as returned by calling the colour map.
    """
    # plt.get_cmap replaces mpl.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap(self.config.cmap)
    return [cmap(value) for value in values]
def _set_data_to_orientation(self, energies, dos_total):
    """Arrange energies and DOS values according to the plot orientation.

    Returns
    -------
    dict
        ``'x'``/``'y'`` (plot coordinates), ``'energies'``, ``'dos_value'``,
        ``'xlim'``/``'ylim'`` (min and max of the x and y data),
        ``'xlabel'``/``'ylabel'`` and ``'fill_func'``. Horizontal orientation
        puts the energies on x and uses ``ax.fill_between``; vertical puts
        them on y and uses ``ax.fill_betweenx``.
    """
    if self.orientation == 'horizontal':
        x_values, y_values = energies, dos_total
        x_text, y_text = self.config.x_label, self.config.y_label
        filler = self.ax.fill_between
    elif self.orientation == 'vertical':
        x_values, y_values = dos_total, energies
        x_text, y_text = self.config.y_label, self.config.x_label
        filler = self.ax.fill_betweenx
    return {
        'x': x_values,
        'y': y_values,
        'energies': energies,
        'dos_value': dos_total,
        'xlim': [x_values.min(), x_values.max()],
        'ylim': [y_values.min(), y_values.max()],
        'xlabel': x_text,
        'ylabel': y_text,
        'fill_func': filler,
    }
def _plot_total_dos(self, energies, dos_total_spin, spin_channel):
    """
    Draws the total DOS of one spin channel as a black line.

    Parameters
    ----------
    energies : np.ndarray
        The energy values
    dos_total_spin : np.ndarray
        The total DOS of the channel, as it should appear on the plot
    spin_channel : int
        The spin channel being plotted; selects the opacity, line style,
        label and line width from the configuration

    Returns
    -------
    None
        None
    """
    oriented = self._set_data_to_orientation(energies, dos_total_spin)
    cfg = self.config
    line_style = dict(
        color='black',
        alpha=cfg.opacity[spin_channel],
        linestyle=cfg.linestyle[spin_channel],
        label=cfg.spin_labels[spin_channel],
        linewidth=cfg.linewidth[spin_channel],
    )
    self.ax.plot(oriented['x'], oriented['y'], **line_style)
def _plot_spin_data_parametric(self, energies, dos_total, normalized_dos_projected):
    """Fill the area under the total DOS, coloured by the normalized projection.

    Parameters
    ----------
    energies : np.ndarray
        The energy values
    dos_total : np.ndarray
        The total DOS of the spin channel
    normalized_dos_projected : np.ndarray
        The values mapped to colours, one per energy point
    """
    segment_colors = self._get_bar_color(normalized_dos_projected)
    oriented = self._set_data_to_orientation(energies, dos_total)
    fill_arguments = {
        'x': oriented['energies'],
        'y': oriented['dos_value'],
        'fill_func': oriented['fill_func'],
        'bar_color': segment_colors,
    }
    self._plot_fill_between(**fill_arguments)
def _plot_spin_data_parametric_line(self, energies,
                                    dos_total_spin, normalized_dos_spin_projected, spin_channel):
    """Draw the total DOS as a line collection coloured by the normalized projection.

    The colour scale uses ``self.clim``, which must already be set (it is
    stored by ``_get_color_limits``). The created collection is appended to
    ``self.handles``.

    Parameters
    ----------
    energies : np.ndarray
        The energy values
    dos_total_spin : np.ndarray
        The total DOS of the spin channel
    normalized_dos_spin_projected : np.ndarray
        The values mapped to colours
    spin_channel : int
        Selects the line width and line style from the configuration
    """
    oriented = self._set_data_to_orientation(energies, dos_total_spin)
    vertices = np.array([oriented['x'], oriented['y']]).T.reshape(-1, 1, 2)
    # Pair every vertex with its successor: N points become N-1 two-point segments.
    starts, ends = vertices[:-1], vertices[1:]
    segments = np.concatenate([starts, ends], axis=1)

    colour_scale = mpl.colors.Normalize(vmin=self.clim[0],
                                        vmax=self.clim[1])
    collection = LineCollection(segments,
                                cmap=plt.get_cmap(self.config.cmap),
                                norm=colour_scale)
    collection.set_array(normalized_dos_spin_projected)
    collection.set_linewidth(self.config.linewidth[spin_channel])
    collection.set_linestyle(self.config.linestyle[spin_channel])
    self.handles.append(self.ax.add_collection(collection))
def _plot_fill_between(self, x, y, fill_func, bottom_value=0, bar_color=None, color=None):
    """Fill the area described by ``x`` and ``y`` using ``fill_func``.

    Parameters
    ----------
    x, y : np.ndarray
        Coordinates along and across the fill direction.
    fill_func : callable
        ``ax.fill_between`` or ``ax.fill_betweenx``.
    bottom_value : float or np.ndarray, optional
        Baseline of the single-colour fill, by default 0. It is not applied
        to the per-segment fill.
    bar_color : list, optional
        One colour per segment; each pair of neighbouring points is filled
        separately with its own colour.
    color : optional
        A single colour for one fill between ``bottom_value`` and
        ``y + bottom_value``.

    Returns
    -------
    The handle returned by the last ``fill_func`` call, or None when nothing
    was drawn (previously that case raised UnboundLocalError).
    """
    handle = None
    if color:
        handle = fill_func(x, y + bottom_value, bottom_value, color=color)
    if bar_color:
        for i in range(len(x) - 1):
            handle = fill_func([x[i], x[i + 1]], [y[i], y[i + 1]], color=bar_color[i])
    return handle
def _plot_spin_stack(self,
                     energies,
                     scaled_projected_dos,
                     bottom_value=0,
                     color=None):
    """Draw one filled layer of a stacked plot on top of ``bottom_value``.

    Returns
    -------
    tuple
        ``(layer, handle)``: the DOS values of the layer just drawn (not the
        accumulated height) and the handle returned by the fill call.
    """
    oriented = self._set_data_to_orientation(energies, scaled_projected_dos)
    layer = oriented['dos_value']
    handle = self._plot_fill_between(
        x=oriented['energies'],
        y=layer,
        fill_func=oriented['fill_func'],
        bottom_value=bottom_value,
        color=color)
    return layer, handle
def _plot_spin_overlay(self,
                       energies,
                       scaled_projected_dos,
                       spin_channel,
                       color=None):
    """Draw one contribution as a line and return its handle.

    ``spin_channel`` selects the opacity, line style, label and line width
    from the configuration. ``ax.plot`` must return exactly one line.
    """
    oriented = self._set_data_to_orientation(energies, scaled_projected_dos)
    cfg = self.config
    lines = self.ax.plot(
        oriented['x'],
        oriented['y'],
        color=color,
        alpha=cfg.opacity[spin_channel],
        linestyle=cfg.linestyle[spin_channel],
        label=cfg.spin_labels[spin_channel],
        linewidth=cfg.linewidth[spin_channel],
    )
    (handle,) = lines
    return handle
[docs]
def set_xticks(self,
               tick_positions:List[int]=None,
               tick_names:List[str]=None):
    """Set the x tick positions and/or x tick labels of the plot.

    Parameters
    ----------
    tick_positions : List[int], optional
        Tick positions; left unchanged when None, by default None
    tick_names : List[str], optional
        Tick labels; left unchanged when None, by default None
    """
    requested = (
        (self.ax.set_xticks, tick_positions),
        (self.ax.set_xticklabels, tick_names),
    )
    for apply, argument in requested:
        if argument is not None:
            apply(argument)
    return None
[docs]
def set_yticks(self,
               tick_positions:List[int]=None,
               tick_names:List[str]=None):
    """A method to set the yticks of the plot

    Parameters
    ----------
    tick_positions : List[int], optional
        A list of tick positions, by default None
    tick_names : List[str], optional
        A list of tick names, by default None
    """
    # These calls previously targeted the x axis (set_xticks / set_xticklabels).
    if tick_positions is not None:
        self.ax.set_yticks(tick_positions)
    if tick_names is not None:
        self.ax.set_yticklabels(tick_names)
    return None
[docs]
def set_xlim(self,
             interval:List[int]=None):
    """Set the x-axis limits of the plot.

    Parameters
    ----------
    interval : List[int], optional
        The x interval; the axis is left unchanged when None, by default None
    """
    if interval is None:
        return None
    self.ax.set_xlim(interval)
    return None
[docs]
def set_ylim(self,
             interval:List[int]=None):
    """Set the y-axis limits of the plot.

    Parameters
    ----------
    interval : List[int], optional
        The y interval; the axis is left unchanged when None, by default None
    """
    if interval is None:
        return None
    self.ax.set_ylim(interval)
    return None
[docs]
def set_xlabel(self, label:str):
    """Set the x-axis label.

    Parameters
    ----------
    label : str
        The x label name. It is only used when ``config.x_label`` is empty;
        a configured label takes precedence.

    Returns
    -------
    None
        None
    """
    self.ax.set_xlabel(self.config.x_label or label)
    return None
[docs]
def set_ylabel(self, label:str):
    """Set the y-axis label.

    Parameters
    ----------
    label : str
        The y label name. It is only used when ``config.y_label`` is empty;
        a configured label takes precedence.

    Returns
    -------
    None
        None
    """
    self.ax.set_ylabel(self.config.y_label or label)
[docs]
def legend(self,
           labels:List[str]=None):
    """A method to include the legend

    The legend is drawn only when ``config.legend`` is set and there is at
    least one label.

    Parameters
    ----------
    labels : List[str], optional
        The labels for the legend; ``self.labels`` is used when None

    Returns
    -------
    None
        None

    Raises
    ------
    ValueError
        If the number of labels differs from the number of stored handles.
    """
    # Identity test: `labels == None` would compare element-wise for array-like labels.
    if labels is None:
        labels = self.labels
    if self.config.legend and len(labels) != 0:
        if len(self.handles) != len(labels):
            raise ValueError(f"The number of labels and handles should be the same, currently there are {len(self.handles)} handles and {len(labels)} labels")
        self.ax.legend(self.handles, labels)
    return None
[docs]
def draw_fermi(self,value, orientation:str='horizontal'):
    """Draw a straight line marking the Fermi energy.

    The colour, line style and line width are taken from ``config.fermi_color``,
    ``config.fermi_linestyle`` and ``config.fermi_linewidth``.

    Parameters
    ----------
    value : float
        Position of the line along the energy axis
    orientation : str, optional
        ``'horizontal'`` draws a vertical line at ``x=value``; ``'vertical'``
        draws a horizontal line at ``y=value``; any other value draws nothing.
        By default 'horizontal'

    Returns
    -------
    None
        None
    """
    if orientation not in ('horizontal', 'vertical'):
        return None
    energy_on_x = orientation == 'horizontal'
    draw_line = self.ax.axvline if energy_on_x else self.ax.axhline
    position = {'x': value} if energy_on_x else {'y': value}
    draw_line(**position,
              color=self.config.fermi_color,
              linestyle=self.config.fermi_linestyle,
              linewidth=self.config.fermi_linewidth)
    return None
[docs]
def grid(self):
    """Draw the grid described by the configuration, if ``config.grid`` is set.

    Returns
    -------
    None
        None
    """
    cfg = self.config
    if not cfg.grid:
        return None
    grid_style = dict(
        which=cfg.grid_which,
        color=cfg.grid_color,
        linestyle=cfg.grid_linestyle,
        linewidth=cfg.grid_linewidth,
    )
    self.ax.grid(cfg.grid, **grid_style)
    return None
[docs]
def show(self):
    """Display the plot by calling ``plt.show()``.

    Returns
    -------
    None
        None
    """
    plt.show()
[docs]
def save(self, filename:str='dos.pdf'
         ):
    """A method to save the plot

    The figure owned by this plotter is written with ``config.dpi`` and a
    tight bounding box, and is cleared afterwards.

    Parameters
    ----------
    filename : str, optional
        The filename, by default 'dos.pdf'

    Returns
    -------
    None
        None
    """
    # Act on this plotter's own figure: plt.savefig()/plt.clf() operate on
    # pyplot's *current* figure, which may belong to another plot.
    self.fig.savefig(filename, dpi=self.config.dpi, bbox_inches="tight")
    self.fig.clf()
    return None
[docs]
def update_config(self, config_dict):
    """Overwrite the ``'value'`` entry of the given configuration options.

    Parameters
    ----------
    config_dict : dict
        Maps an option name to its new value. Each name is looked up with
        ``self.config[name]`` and the new value is stored under ``'value'``.
    """
    for option, new_value in config_dict.items():
        entry = self.config[option]
        entry['value'] = new_value
[docs]
def export_data(self,filename):
    """
    This method will export the plotted data to a csv, txt, dat or json file

    The csv/txt/dat files are tables separated by ``,``, a tab and a space
    respectively. Columns are ordered as: energies, total DOS per spin,
    projections of spin channel 0, projections of spin channel 1, then any
    other column. ``self.values_dict`` is not modified.

    Parameters
    ----------
    filename : str
        The file name to export the data to; its extension selects the format

    Returns
    -------
    None
        None

    Raises
    ------
    ValueError
        If the extension is not supported, or if nothing has been plotted yet.
    """
    possible_file_types = ['csv', 'txt', 'json', 'dat']
    file_type = filename.split('.')[-1]
    if file_type not in possible_file_types:
        raise ValueError(f"The file type must be {possible_file_types}")
    # values_dict starts out as an empty dict, so test for emptiness rather than None.
    if not self.values_dict:
        raise ValueError("The data has not been plotted yet")

    if file_type == 'json':
        # Convert into a fresh dict: rewriting self.values_dict in place turned the
        # arrays into lists and made a second JSON export fail on `.tolist()`.
        serializable = {key: np.asarray(value).tolist()
                        for key, value in self.values_dict.items()}
        with open(filename, 'w') as outfile:
            json.dump(serializable, outfile)
        return None

    column_groups = ('energies', 'dosTotalSpin', 'spinChannel-0', 'spinChannel-1')

    def _group_rank(column_name):
        # Rank by the first matching group marker in the text before the first '_'.
        head = column_name.split('_')[0]
        for rank, marker in enumerate(column_groups):
            if marker in head:
                return rank
        return len(column_groups)

    # sorted() is stable, so columns keep their insertion order within a group.
    sorted_column_names = sorted(self.values_dict, key=_group_rank)
    separators = {'csv': ',', 'txt': '\t', 'dat': ' '}
    df = pd.DataFrame(self.values_dict)
    df.to_csv(filename, columns=sorted_column_names, sep=separators[file_type], index=False)
    return None