Source code for opt_diag

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import diabfunct
import symm_vcham
from typing import List, Tuple

from logging_config import get_logger
logger = get_logger(__name__)

"""
This module is intended for modes in which the potential energy surfaces (PES) do not cross.
"""

class OptimizedDiagParameters:
    """Fit diagonal diabatic potential functions along a single normal mode.

    For every electronic state of the selected normal mode, the ab initio
    energies are shifted so that the point at zero displacement is the energy
    origin, and the analytic function named in ``diab_functions`` is fitted to
    the shifted curve with ``scipy.optimize.curve_fit``.

    Attributes set by ``__init__``:
        displacement_vector: displacements at which the energies are given.
        normal_mode: index of the mode being fitted.
        data_db: ab initio data, indexed as ``data_db[normal_mode][state]``.
        diab_functions: lower-cased function name for each state.
        vib_frequencies: per-mode frequencies; ``omega`` is the entry for
            ``normal_mode``.
        symmetry_point_group / sym_mode: upper-cased symmetry labels.
        potential_functions: mapping from function name to the callable that
            is fitted.
    """

    # Starting value for the linear (kappa) coefficient when it is prepended to
    # an initial guess. Value carried over unchanged from the original code.
    _KAPPA_INITIAL_GUESS = 0.04098793

    def __init__(self, displacement_vector: np.ndarray = None, normal_mode: int = None,
                 data_db: np.ndarray = None, diab_functions: List[str] = None,
                 vib_frequencies: np.ndarray = None, symmetry_point_group: str = None,
                 sym_mode: str = None, VCSystem = None):
        """Initialize the parameters for optimization.

        The data can be passed explicitly or taken from ``VCSystem``. When
        ``VCSystem`` is given, its attributes (``displacement_vector``,
        ``database_abinitio``, ``diab_funct``, ``vib_freq``,
        ``symmetry_point_group``, ``symmetry_modes``) take precedence, and the
        explicit arguments only act as fallbacks for missing attributes. The
        per-mode attributes are indexed with ``normal_mode``.

        Raises:
            Exception: any error raised while storing the parameters is logged
                and re-raised unchanged.
        """
        if VCSystem is not None:
            logger.debug("Initializing from VCSystem object.")
            displacement_vector = getattr(VCSystem, 'displacement_vector', displacement_vector)[normal_mode]
            data_db = getattr(VCSystem, 'database_abinitio', data_db)
            diab_functions = getattr(VCSystem, 'diab_funct', diab_functions)[normal_mode]
            vib_frequencies = getattr(VCSystem, 'vib_freq', vib_frequencies)
            symmetry_point_group = getattr(VCSystem, 'symmetry_point_group', symmetry_point_group)
            sym_mode = getattr(VCSystem, 'symmetry_modes', sym_mode)[normal_mode]

        try:
            logger.debug("Initializing parameters for optimization.")
            self.displacement_vector = displacement_vector
            logger.debug("displacement_vector: %s", displacement_vector)
            self.normal_mode = normal_mode
            # %s rather than %d: normal_mode defaults to None.
            logger.debug("normal_mode: %s", normal_mode)
            self.data_db = data_db
            logger.debug("data_db: %s", data_db)
            self.diab_functions = [funct.lower() for funct in diab_functions]
            logger.debug("diab_functions: %s", self.diab_functions)
            self.vib_frequencies = vib_frequencies
            logger.debug("vib_frequencies: %s", vib_frequencies)
            self.omega = self.vib_frequencies[self.normal_mode]
            logger.debug("omega: %f", self.omega)
            self.symmetry_point_group = symmetry_point_group.upper()
            logger.debug("symmetry_point_group: %s", self.symmetry_point_group)
            self.sym_mode = sym_mode.upper()
            logger.debug("sym_mode: %s", self.sym_mode)

            # "ho" and "quartic" use the frequency of this mode; the other
            # entries are passed to curve_fit as they are.
            self.potential_functions = {
                "ho": lambda q: diabfunct.harmonic_oscillator(q, self.omega),
                "quartic": lambda q, k2, k3: diabfunct.general_quartic_potential(q, self.omega, k2, k3),
                "morse": diabfunct.general_morse_np,
                "antimorse": diabfunct.general_morse_np,
                "kappa": diabfunct.kappa,
            }
            logger.debug("potential_functions initialized.")
        except Exception as e:
            logger.exception("An error occurred during initialization: %s", e)
            raise

    @staticmethod
    def _with_kappa(potential_fn):
        """Return ``potential_fn`` plus a kappa term driven by the first parameter."""
        def shifted(q, *args):
            return potential_fn(q, *args[1:]) + diabfunct.kappa(q, args[0])
        return shifted

    @classmethod
    def _build_initial_guess(cls, previous, default, n_base, with_kappa):
        """Choose the ``p0`` vector for one fit.

        ``previous`` (the last successful fit) is reused only when its length
        is compatible with the function about to be fitted: ``n_base``
        parameters, or ``n_base + 1`` when a kappa term is included. Otherwise
        ``default`` is used. When a kappa term is requested and the chosen
        guess has only ``n_base`` entries, the kappa starting value is
        prepended.
        """
        compatible_lengths = (n_base, n_base + 1) if with_kappa else (n_base,)
        guess = default
        if previous is not None and len(previous) in compatible_lengths:
            guess = previous
        if with_kappa and np.shape(guess)[0] == n_base:
            return np.concatenate((np.array([cls._KAPPA_INITIAL_GUESS]), guess))
        return guess

    @staticmethod
    def _stack_parameters(parameter_sets):
        """Convert a list of per-state parameter lists to an array.

        Equal-length sets give a regular 2-D array. Sets of different lengths
        give a 1-D object array, because recent NumPy versions refuse to build
        an array from a ragged sequence implicitly.
        """
        if len({len(params) for params in parameter_sets}) <= 1:
            return np.array(parameter_sets)
        stacked = np.empty(len(parameter_sets), dtype=object)
        for index, params in enumerate(parameter_sets):
            stacked[index] = params
        return stacked

    def _find_reference_index(self):
        """Return the index of the first displacement with ``|q| < 1e-6``.

        Raises:
            ValueError: if no displacement is that close to zero.
        """
        close_to_zero_indices = np.where(np.abs(self.displacement_vector) < 1e-6)[0]
        if close_to_zero_indices.size == 0:
            logger.error("Did not find 0.0 point in displacement_vector.")
            raise ValueError("Did not find 0.0 point")
        return close_to_zero_indices[0]

    def optimize_parameters(self) -> Tuple[np.ndarray, np.ndarray]:
        """Fit the diabatic function of every state of the selected mode.

        Returns:
            optimized_parameters: array with one fitted parameter set per
                successfully fitted state. It is a 1-D object array when the
                states have different numbers of parameters.
            e0_constants: array of the vertical energy shifts, i.e. the
                ab initio energy of each fitted state at zero displacement.

        Raises:
            ValueError: if a state names an unknown function type, or if the
                displacement vector has no point at zero.

        Note:
            A state whose fit raises ``RuntimeError`` is logged and skipped, so
            the returned arrays can be shorter than the number of states and
            their row indices then no longer match the state indices.
        """
        optimized_parameters = []
        e0_constants = []
        previous_popt = None

        mode_data = self.data_db[self.normal_mode]
        logger.info("Data shape for normal mode %s: %s", self.normal_mode, np.shape(mode_data))

        # Both quantities are the same for every state, so compute them once.
        reference_index = self._find_reference_index()
        total_sym_irrep = symm_vcham.SymmetryMask.get_total_sym_irrep(self.symmetry_point_group)

        for state, abinit_data in enumerate(mode_data):
            logger.info("Processing state: %s", state)
            function_type = self.diab_functions[state]
            if function_type not in self.potential_functions:
                logger.error("Unknown function type '%s' encountered.", function_type)
                raise ValueError(f"Unknown function type '{function_type}' encountered.")

            # Use the energy at zero displacement as the origin of this curve.
            min_energy = abinit_data[reference_index]
            abinit_data_adj = abinit_data - min_energy

            # A kappa term is added only for "ho"/"quartic" in a totally
            # symmetric mode; per the original author's note, the Morse-type
            # functions already include the shift (q - q0).
            with_kappa = function_type in ("ho", "quartic") and self.sym_mode == total_sym_irrep
            fit_function = self.potential_functions[function_type]
            if with_kappa:
                fit_function = self._with_kappa(fit_function)

            initial_guess = self._build_initial_guess(
                previous_popt,
                diabfunct.initial_guesses[function_type],
                diabfunct.n_var[function_type],
                with_kappa,
            )

            try:
                popt, _ = curve_fit(fit_function, self.displacement_vector,
                                    abinit_data_adj, p0=initial_guess)
            except RuntimeError as e:
                logger.error("Curve fitting failed for state %s: %s", state, e)
                continue

            previous_popt = popt
            optimized_parameters.append(list(popt))
            e0_constants.append(min_energy)

        return self._stack_parameters(optimized_parameters), np.array(e0_constants)

    def plot_data(self, fitted_params, e0_shift, state):
        """Plot the ab initio data of one state together with its fitted curve.

        Parameters:
            fitted_params: per-state fitted parameter sets, as returned by
                ``optimize_parameters``; ``fitted_params[state]`` is used.
            e0_shift: vertical energy shift added to the fitted curve.
            state: index of the state to plot.
        """
        plt.figure(figsize=(10, 6))

        abinit_data = self.data_db[self.normal_mode][state]
        plt.scatter(self.displacement_vector, abinit_data, label=f'Abinit S{state}', color='blue')

        function_type = self.diab_functions[state]
        curve_function = self.potential_functions[function_type]
        state_params = fitted_params[state]

        # An extra parameter relative to the bare function means that a kappa
        # term was fitted as the first parameter.
        if len(state_params) != diabfunct.n_var[function_type]:
            logger.info("In mode %s, state %s, kappa in non-zero", self.normal_mode, state)
            curve_function = self._with_kappa(curve_function)

        fitted_curve = curve_function(self.displacement_vector, *state_params) + e0_shift
        plt.plot(self.displacement_vector, fitted_curve, label=f'Fitted S{state}', color='red')

        plt.xlabel('Displacement')
        plt.ylabel('Potential Energy')
        plt.title(f'Fit for Mode {self.normal_mode} State {state}')
        plt.legend()
        plt.show()