import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import diabfunct
import symm_vcham
from typing import List, Tuple
from logging_config import get_logger
# Module-level logger, obtained from the project's logging_config helper.
logger = get_logger(__name__)
"""
Use this module when the potential energy surfaces (PES) do not cross along a mode.
"""
class OptimizedDiagParameters:
    """
    Fit per-state diabatic potential functions along a single normal mode.

    For every state stored in ``data_db[normal_mode]`` the ab initio energies
    are shifted so that the energy at zero displacement is zero, and the
    potential selected by ``diab_functions[state]`` is fitted to the shifted
    curve with ``scipy.optimize.curve_fit``.

    Attributes set by ``__init__``:
        displacement_vector, normal_mode, data_db, vib_frequencies: stored as given.
        diab_functions: lower-cased function names, one per state.
        omega: ``vib_frequencies[normal_mode]``.
        symmetry_point_group, sym_mode: upper-cased labels.
        potential_functions: mapping from function name to the callable fitted.
    """

    # Starting value used for the extra kappa parameter when it is prepended
    # to an initial guess (value kept from the original implementation).
    _KAPPA_INITIAL_GUESS = 0.04098793

    # Function types that get an additional kappa term for totally symmetric
    # modes (per the original comment, the Morse form already includes the
    # (q - q0) shift).
    _KAPPA_FUNCTION_TYPES = ("ho", "quartic")

    def __init__(self, displacement_vector: np.ndarray = None, normal_mode: int = None, data_db: np.ndarray = None,
                 diab_functions: List[str] = None, vib_frequencies: np.ndarray = None,
                 symmetry_point_group: str = None, sym_mode: str = None, VCSystem=None):
        """
        Initialize the parameters for optimization.

        Parameters:
            displacement_vector: displacements along the mode at which the
                energies in ``data_db`` were computed.
            normal_mode: index of the normal mode to fit.
            data_db: ab initio energies, indexed as ``data_db[normal_mode][state]``.
            diab_functions: function name per state; each must be a key of
                ``potential_functions`` ("ho", "quartic", "morse", "antimorse", "kappa").
            vib_frequencies: frequencies indexed by normal mode.
            symmetry_point_group: point-group label (stored upper-cased).
            sym_mode: irrep label of the mode (stored upper-cased).
            VCSystem: optional object; when given, its attributes
                ``displacement_vector``, ``database_abinitio``, ``diab_funct``,
                ``vib_freq``, ``symmetry_point_group`` and ``symmetry_modes``
                replace the corresponding arguments (the explicit argument is
                only the fallback when an attribute is missing). The
                displacement vector, function list and mode symmetry are
                additionally indexed by ``normal_mode``.

        Raises:
            Exception: any error raised while reading the inputs is logged
                and re-raised unchanged.
        """
        try:
            logger.debug("Initializing parameters for optimization.")
            if VCSystem is not None:
                logger.debug("Initializing from VCSystem object.")
                displacement_vector = getattr(VCSystem, 'displacement_vector', displacement_vector)[normal_mode]
                data_db = getattr(VCSystem, 'database_abinitio', data_db)
                diab_functions = getattr(VCSystem, 'diab_funct', diab_functions)[normal_mode]
                vib_frequencies = getattr(VCSystem, 'vib_freq', vib_frequencies)
                symmetry_point_group = getattr(VCSystem, 'symmetry_point_group', symmetry_point_group)
                sym_mode = getattr(VCSystem, 'symmetry_modes', sym_mode)[normal_mode]

            self.displacement_vector = displacement_vector
            logger.debug("displacement_vector: %s", displacement_vector)
            self.normal_mode = normal_mode
            # %s rather than %d so that a missing (None) mode is still loggable.
            logger.debug("normal_mode: %s", normal_mode)
            self.data_db = data_db
            logger.debug("data_db: %s", data_db)
            self.diab_functions = [funct.lower() for funct in diab_functions]
            logger.debug("diab_functions: %s", self.diab_functions)
            self.vib_frequencies = vib_frequencies
            logger.debug("vib_frequencies: %s", vib_frequencies)
            self.omega = self.vib_frequencies[self.normal_mode]
            logger.debug("omega: %f", self.omega)
            self.symmetry_point_group = symmetry_point_group.upper()
            logger.debug("symmetry_point_group: %s", self.symmetry_point_group)
            self.sym_mode = sym_mode.upper()
            logger.debug("sym_mode: %s", self.sym_mode)
            # "ho" and "quartic" close over self.omega; the other entries are
            # passed to curve_fit as-is.
            self.potential_functions = {
                "ho": lambda q: diabfunct.harmonic_oscillator(q, self.omega),
                "quartic": lambda q, k2, k3: diabfunct.general_quartic_potential(q, self.omega, k2, k3),
                "morse": diabfunct.general_morse_np,
                "antimorse": diabfunct.general_morse_np,
                "kappa": diabfunct.kappa
            }
            logger.debug("potential_functions initialized.")
        except Exception as e:
            logger.exception("An error occurred during initialization: %s", e)
            raise

    @staticmethod
    def _with_kappa(potential_fn):
        """Return ``potential_fn`` extended by a kappa term whose parameter comes first."""
        def potential_with_kappa(q, *params):
            return potential_fn(q, *params[1:]) + diabfunct.kappa(q, params[0])
        return potential_with_kappa

    @staticmethod
    def _pack_parameters(parameter_sets):
        """
        Convert a list of per-state parameter sequences to a NumPy array.

        Equal-length sets give the usual 2-D array. Sets of different lengths
        (states fitted with different function types) cannot form a
        rectangular array, and recent NumPy versions raise on
        ``np.array(ragged)``; they are returned as a 1-D object array whose
        elements are the individual parameter arrays.
        """
        if len({len(params) for params in parameter_sets}) <= 1:
            return np.array(parameter_sets)
        packed = np.empty(len(parameter_sets), dtype=object)
        for index, params in enumerate(parameter_sets):
            packed[index] = np.asarray(params)
        return packed

    def _initial_guess(self, function_type, prev_popt, with_kappa):
        """
        Choose the starting parameters for one fit.

        The previous state's optimum is reused only when its length fits the
        function being fitted now; otherwise the default guess from
        ``diabfunct.initial_guesses`` is used. When a kappa term is fitted and
        the guess has exactly ``diabfunct.n_var[function_type]`` entries, the
        kappa starting value is prepended.
        """
        n_base = diabfunct.n_var[function_type]
        kappa_start = np.array([self._KAPPA_INITIAL_GUESS])

        if prev_popt is not None:
            n_prev = np.size(prev_popt)
            if n_prev == n_base + (1 if with_kappa else 0):
                return prev_popt
            if with_kappa and n_prev == n_base:
                return np.concatenate((kappa_start, prev_popt))
            # Otherwise the previous optimum belongs to a function with a
            # different number of parameters and cannot seed this fit.

        default_guess = diabfunct.initial_guesses[function_type]
        if with_kappa and np.shape(default_guess)[0] == n_base:
            return np.concatenate((kappa_start, default_guess))
        return default_guess

    def optimize_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Fit the selected potential to every state of the mode.

        Returns:
            optimized_parameters: NumPy array with one entry per successfully
                fitted state. It is 2-D when all states have the same number
                of parameters, otherwise a 1-D object array of parameter arrays.
            e0_constants: NumPy array with the energy at zero displacement of
                each successfully fitted state (the vertical shift removed
                before fitting).

        Raises:
            ValueError: if the displacement vector has no point within 1e-6
                of zero, or a state names an unknown function type.

        Note:
            A state whose fit fails with ``RuntimeError`` is logged and
            skipped, so entry ``i`` of the returned arrays does not
            necessarily belong to state ``i``.
        """
        optimized_parameters = []
        e0_constants = []
        prev_popt = None

        mode_data = self.data_db[self.normal_mode]
        # Some data shape for debugging
        data_shape = np.shape(mode_data)
        logger.info(f"Data shape for normal mode {self.normal_mode}: {data_shape}")

        # Finding the Franck-Condon point (zero displacement); it is the same
        # for every state, so it is located once.
        close_to_zero_indices = np.where(np.abs(self.displacement_vector) < 1e-6)[0]
        if close_to_zero_indices.size == 0:
            logger.error("Did not find 0.0 point in displacement_vector.")
            raise ValueError("Did not find 0.0 point")
        first_index = close_to_zero_indices[0]

        # Check for totally symmetric mode (also state-independent).
        total_sym_irrep = symm_vcham.SymmetryMask.get_total_sym_irrep(self.symmetry_point_group)
        mode_is_totally_symmetric = self.sym_mode == total_sym_irrep

        for state, abinit_data in enumerate(mode_data):
            logger.info(f"Processing state: {state}")
            function_type = self.diab_functions[state]
            if function_type not in self.potential_functions:
                logger.error(f"Unknown function type '{function_type}' encountered.")
                raise ValueError(f"Unknown function type '{function_type}' encountered.")

            # Shifting energies so the curve is zero at zero displacement
            abinit_data = np.asarray(abinit_data)
            min_energy = abinit_data[first_index]
            abinit_data_adj = abinit_data - min_energy

            potential_fn = self.potential_functions[function_type]
            with_kappa = mode_is_totally_symmetric and function_type in self._KAPPA_FUNCTION_TYPES
            fit_fn = self._with_kappa(potential_fn) if with_kappa else potential_fn
            initial_guess = self._initial_guess(function_type, prev_popt, with_kappa)

            # Fit procedure
            try:
                popt, _ = curve_fit(fit_fn, self.displacement_vector, abinit_data_adj, p0=initial_guess)
            except RuntimeError as e:
                logger.error(f"Curve fitting failed for state {state}: {e}")
                continue

            prev_popt = popt
            optimized_parameters.append(popt)
            e0_constants.append(min_energy)

        # Save parameters and energy shifts
        return self._pack_parameters(optimized_parameters), np.array(e0_constants)

    def plot_data(self, fitted_params, e0_shift, state):
        """
        Plot the ab initio data of one state together with its fitted potential.

        Parameters:
        - fitted_params: sequence indexed by state, holding the parameters obtained from fitting
        - e0_shift: vertical energy shift added to the fitted curve
        - state: int, the index of the state data set to plot

        A kappa term is assumed to be present when the number of parameters
        for the state differs from ``diabfunct.n_var`` for its function type;
        the first parameter is then taken as kappa.
        """
        plt.figure(figsize=(10, 6))

        # Extract ab initio data for the specified state and plot it
        abinit_data = self.data_db[self.normal_mode][state]
        plt.scatter(self.displacement_vector, abinit_data, label=f'Abinit S{state}', color='blue')

        # Retrieve the potential function based on the specified state
        function_type = self.diab_functions[state]
        potential_fn = self.potential_functions[function_type]
        state_params = fitted_params[state]

        # Check if there is a kappa term in the fitted parameters
        if len(state_params) != diabfunct.n_var[function_type]:
            logger.info(f"In mode {self.normal_mode}, state {state}, kappa in non-zero")
            curve_fn = self._with_kappa(potential_fn)
        else:
            # No kappa term: evaluate the bare potential.
            curve_fn = potential_fn

        # Calculate the fitted potential curve and undo the vertical shift
        fitted_curve = curve_fn(self.displacement_vector, *state_params) + e0_shift

        # Plot the fitted potential curve
        plt.plot(self.displacement_vector, fitted_curve, label=f'Fitted S{state}', color='red')

        # Set plot labels and title
        plt.xlabel('Displacement')
        plt.ylabel('Potential Energy')
        plt.title(f'Fit for Mode {self.normal_mode} State {state}')
        plt.legend()
        plt.show()