Source code for lvc

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import diabfunct
import couplingfunct
import symm_vcham
import time
from typing import List, Optional, Tuple, Any, Dict
from dataclasses import dataclass, field

from logging_config import get_logger
logger = get_logger(__name__)


[docs]
class LVCHam:
    """
    Builder and optimizer for an LVC Hamiltonian.
    """
[docs]
    def __init__(
        self,
        normal_mode: int,
        VCSystem: Any,
        funct_guess: Optional[Any] = None,
        nepochs: int = 2000,
    ) -> None:
        """
        Initialize LVCHam with references to the system and its parameters.

        Parameters
        ----------
        normal_mode : int
            The vibrational mode to process.
        VCSystem : object
            The parent system containing parameters and data.
        funct_guess : optional
            Initial guess parameters for the diabatic functions.
        nepochs : int, optional
            Number of optimization epochs.
        """
        self._validate_constructor_input(normal_mode, VCSystem)

        self.normal_mode: int = normal_mode
        self.VCSystem: Any = VCSystem
        self.guess_params: Optional[Any] = funct_guess
        self.nepochs: int = nepochs
        self.optimizer = tf.optimizers.Adam(learning_rate=0.001)

        logger.info("\n\n------------- INIT LVC HAMILTONIAN BUILDER -------------")
        logger.info("normal_mode: %s", normal_mode)

        # Set up system-dependent attributes
        self._initialize_object_params()        # Extract from VCSystem
        self._initialize_internal_variables()   # Prepare bookkeeping variables
    # ---------------------------------------------------------------------------
    # Off-diagonal Parameter Generation
    # ---------------------------------------------------------------------------
    def _create_off_diag_coupling_param(
        self, lambda_guess: float = 0.01, jt_guess: float = 0.01
    ) -> Tuple[List[List[int]], List[List[int]]]:
        """
        Generate off-diagonal coupling parameters for the Hamiltonian.

        Returns
        -------
        tuple
            A tuple containing two lists:
            - lambda indices (upper-triangular entries with value 1)
            - JT off-diagonal indices (upper-triangular entries with value 2)
        """
        # Isolate the upper-triangular portion (excluding the diagonal)
        upper_triangular = tf.linalg.band_part(self.symmetry_mask, 0, -1) - tf.linalg.diag(
            tf.linalg.diag_part(self.symmetry_mask)
        )

        # Identify indices where off-diagonal lambda contributions are present
        lambda_idx_tensor = tf.where(tf.equal(upper_triangular, 1))
        self.lambda_idx: List[List[int]] = [
            [int(i[0]), int(i[1])] for i in lambda_idx_tensor.numpy()
        ]

        if self.lambda_idx:
            self.lambda_param = self._process_guess(
                lambda_guess, len(self.lambda_idx), name="lambda_param"
            )
            logger.info("Lambda nonzero pairs: %s", self.lambda_idx)
        else:
            logger.info("No off-diagonal lambda needed for this mode.")
            self.lambda_param = None

        # Process JT off-diagonal indices (symmetry_mask == 2)
        self.create_jt_off_diag(jt_guess=jt_guess)
        self._initialize_diab_fn_variables()

        return self.lambda_idx, self.jt_off_idx
[docs]
    def create_jt_off_diag(self, jt_guess: float = 0.01) -> None:
        """
        Identify JT-type off-diagonal terms (symmetry_mask == 2) and initialize JT parameters.
        """
        upper_triangular = tf.linalg.band_part(self.symmetry_mask, 0, -1) - tf.linalg.diag(
            tf.linalg.diag_part(self.symmetry_mask)
        )
        jt_off_tensor = tf.where(tf.equal(upper_triangular, 2))
        self.jt_off_idx: List[List[int]] = [
            [int(i[0]), int(i[1])] for i in jt_off_tensor.numpy()
        ]
        logger.info("JT off-diagonal pairs: %s", self.jt_off_idx)

        if self.jt_off_idx:
            self.jt_off_param = self._process_guess(
                jt_guess, len(self.jt_off_idx), name="jt_off_param"
            )
        else:
            self.jt_off_param = None
    # ---------------------------------------------------------------------------
    # On-diagonal Parameter Generation
    # ---------------------------------------------------------------------------
    def _create_on_diag_coupling_param(
        self, kappa_guess: float = 0.1
    ) -> Tuple[List[int], List[int]]:
        """
        Generate on-diagonal coupling (kappa) parameters for the Hamiltonian.

        Returns
        -------
        tuple
            A tuple containing:
            - A list of indices (states) for which kappa is active.
            - A list of JT on-diagonal indices (states with symmetry mask value 2).
        """
        diag_sym_mask = tf.linalg.diag_part(self.symmetry_mask)
        kappa_candidates = tf.where(tf.equal(diag_sym_mask, 1))

        self.kappa_idx: List[int] = []
        for idx in kappa_candidates.numpy():
            st = int(idx[0])
            if diabfunct.kappa_compatible[self.diab_functions[st]]:
                self.kappa_idx.append(st)

        if self.kappa_idx:
            self.kappa_param = self._process_guess(
                kappa_guess, len(self.kappa_idx), name="kappa_param"
            )
            logger.info("Non-zero kappa states: %s", self.kappa_idx)
        else:
            logger.info("No on-diagonal kappa parameters needed for this mode.")
            self.kappa_param = None

        # Mark summary output for on-diagonal states
        for s in self.kappa_idx:
            self.summary_output[s] = "kappa"

        jt_on_tensor = tf.where(tf.equal(diag_sym_mask, 2))
        self.jt_on_idx: List[int] = [int(i) for i in jt_on_tensor.numpy()]

        return self.kappa_idx, self.jt_on_idx

    # ---------------------------------------------------------------------------
    # Parameter Initialization
    # ---------------------------------------------------------------------------
[docs]
    def initialize_params(
        self, lambda_guess: float = 0.1, jt_guess: float = 0.01, kappa_guess: float = 0.1
    ) -> None:
        """
        Collect all TF Variables that will be optimized.

        This method first creates the on-diagonal and off-diagonal parameters and
        then gathers every resulting variable into the optimization list.
        """
        idx_on_diag = self._create_on_diag_coupling_param(kappa_guess=kappa_guess)
        idx_off_diag = self._create_off_diag_coupling_param(
            lambda_guess=lambda_guess, jt_guess=jt_guess
        )

        # Clear any previous optimization parameter list and count
        self.optimize_params = []
        self.ntotal_param = 0

        # Collect parameters from the various sources if they exist.
        if self.funct_param is not None and self.funct_param.shape[0] > 0:
            logger.debug("funct_param shape: %d", self.funct_param.shape[0])
            self.optimize_params.append(self.funct_param)
            self.ntotal_param += int(np.prod(self.funct_param.shape))
        if self.lambda_param is not None:
            self.optimize_params.append(self.lambda_param)
            self.ntotal_param += int(np.prod(self.lambda_param.shape))
        if self.kappa_param is not None:
            self.optimize_params.append(self.kappa_param)
            self.ntotal_param += int(np.prod(self.kappa_param.shape))
        if hasattr(self, "jt_on_param") and self.jt_on_param is not None:
            self.optimize_params.append(self.jt_on_param)
            self.ntotal_param += int(np.prod(self.jt_on_param.shape))
        if self.jt_off_param is not None:
            self.optimize_params.append(self.jt_off_param)
            self.ntotal_param += int(np.prod(self.jt_off_param.shape))

        logger.info("Total number of parameters to optimize: %s", self.ntotal_param)
        logger.info("Optimize_params: %s", self.optimize_params)

        # Update indices in the VCSystem's bookkeeping dictionary.
        self.VCSystem.idx_dict["kappa"][self.normal_mode] = idx_on_diag[0]
        self.VCSystem.idx_dict["jt_on"][self.normal_mode] = idx_on_diag[1]
        self.VCSystem.idx_dict["lambda"][self.normal_mode] = idx_off_diag[0]
        self.VCSystem.idx_dict["jt_off"][self.normal_mode] = idx_off_diag[1]
[docs]
    def initialize_loss_function(self, fn: str = "huber", **kwargs) -> None:
        """
        Initialize a TensorFlow loss function by name.

        Parameters
        ----------
        fn : str, optional
            Name of the loss function (default is 'huber').
        kwargs :
            Additional keyword arguments for the loss function.
        """
        loss_fn_map = {
            "huber": lambda: tf.keras.losses.Huber(**kwargs),
            "mse": lambda: tf.keras.losses.MeanSquaredError(),
            "mae": lambda: tf.keras.losses.MeanAbsoluteError(),
            "msle": lambda: tf.keras.losses.MeanSquaredLogarithmicError(),
            "logcosh": lambda: tf.keras.losses.LogCosh(),
            "kld": lambda: tf.keras.losses.KLDivergence(),
            "poisson": lambda: tf.keras.losses.Poisson(),
            "cosine": lambda: tf.keras.losses.CosineSimilarity(),
            "sparse": lambda: tf.keras.losses.SparseCategoricalCrossentropy(),
            "binary": lambda: tf.keras.losses.BinaryCrossentropy(),
        }
        fn_lower = fn.lower()
        if fn_lower not in loss_fn_map:
            raise NotImplementedError(f"Loss function '{fn_lower}' not implemented.")
        self.loss_fn = loss_fn_map[fn_lower]()
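    # Illustrative usage sketch (not part of the class): the loss must be set
    # before `optimize()` is called. Note that only the "huber" entry above
    # forwards **kwargs, so keyword arguments such as `delta` only take effect
    # for that choice. `ham` is assumed to be an already constructed LVCHam
    # instance.
    #
    #     ham.initialize_loss_function("huber", delta=0.5)  # Huber loss with a custom delta
    #     ham.initialize_loss_function("mse")               # plain mean-squared error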
    # ---------------------------------------------------------------------------
    # Optimization and Training Loop
    # ---------------------------------------------------------------------------
[docs]
    def optimize(self) -> None:
        """
        Run gradient-based optimization for the specified number of epochs or until
        early stopping is triggered. If the current mode is marked inactive, build
        the final Hamiltonian tensor directly without optimization.
        """
        # Check for inactive mode (e.g., for an Exe JT effect where parameters are copied)
        inactive_mode = False
        if self.JT_effects is not None:
            for effect in self.JT_effects:
                if effect.get("mode") == self.normal_mode and effect.get("active", True) is False:
                    inactive_mode = True
                    break

        if inactive_mode or self.ntotal_param == 0:
            final_tensor = self.build_vcham_tensor()
            logger.info("Mode inactive or no parameters to optimize; final tensor built directly.")
            self.plot_results(final_tensor.numpy(), loss_history=[])
            logger.info("------- No optimization performed -------")
            self.VCSystem.summary_output[self.normal_mode] = self.summary_output
            self.VCSystem.optimized_params[self.normal_mode] = self.optimize_params
            return

        t0 = time.perf_counter()
        best_loss = float("inf")
        patience = 10000
        patience_counter = 0
        loss_history: List[float] = []

        for step in range(self.nepochs):
            loss_val = self.train_step()
            current_loss = float(loss_val.numpy())
            loss_history.append(current_loss)

            if step % 100 == 0:
                logger.info("Step %d, Loss: %.6f", step, current_loss)

            if current_loss < (best_loss - 1e-5):
                best_loss = current_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info("Early stopping at step %d.", step)
                    break

        tf_time = time.perf_counter() - t0
        final_tensor = self.build_vcham_tensor().numpy()
        self.plot_results(final_tensor, loss_history)

        if not inactive_mode and hasattr(self, "active_jt_states"):
            self._save_jt()

        logger.info("Summary_output: %s", self.summary_output)
        logger.info("Final loss: %.6f", best_loss)
        logger.info("Optimization finished in %.2f seconds.", tf_time)
        self.VCSystem.summary_output[self.normal_mode] = self.summary_output
        self.VCSystem.optimized_params[self.normal_mode] = self.optimize_params
[docs]
    def plot_results(self, final_tensor_np: np.ndarray, loss_history: List[float]) -> None:
        """
        Plot the final eigenvalues vs. reference data and the loss evolution.

        Parameters
        ----------
        final_tensor_np : np.ndarray
            The final Hamiltonian eigenvalues.
        loss_history : list of float
            The loss evolution during optimization.
        """
        disp_np = (
            self.displacement_vector.numpy()
            if isinstance(self.displacement_vector, tf.Tensor)
            else self.displacement_vector
        )
        data_db_np = (
            self.data_db.numpy()
            if isinstance(self.data_db, tf.Tensor)
            else self.data_db
        )

        if self.coupling_with_gs:
            plt.figure(figsize=(10, 6))
            for i in range(final_tensor_np.shape[0]):
                plt.plot(disp_np, final_tensor_np[i], label=f"State {i}")
                plt.scatter(disp_np, data_db_np[i], s=5)
            plt.xlabel(f"$\\text{{Q}}_{{{self.normal_mode}}} - \\text{{irrep: {self.sym_mode}}}$")
            plt.ylabel(f"Energy [{self.VCSystem.units}]")
            plt.legend(prop={"size": 6})
            plt.show()
        else:
            plt.figure(figsize=(10, 6))
            plt.plot(disp_np, final_tensor_np[0], label="Ground State")
            plt.scatter(disp_np, data_db_np[0], s=5)
            plt.xlabel(f"$\\text{{Q}}_{{{self.normal_mode}}} - \\text{{irrep: {self.sym_mode}}}$")
            plt.ylabel(f"Energy [{self.VCSystem.units}]")
            plt.legend(prop={"size": 7})
            plt.show()

            plt.figure(figsize=(10, 6))
            for i in range(1, final_tensor_np.shape[0]):
                plt.plot(disp_np, final_tensor_np[i], label=f"State {i}")
                plt.scatter(disp_np, data_db_np[i], s=5)
            plt.xlabel(f"$\\text{{Q}}_{{{self.normal_mode}}} - \\text{{irrep: {self.sym_mode}}}$")
            plt.ylabel(f"Energy [{self.VCSystem.units}]")
            plt.legend(prop={"size": 7})
            plt.show()

        if loss_history:
            plt.figure(figsize=(5, 3))
            plt.plot(loss_history)
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.grid(True)
            plt.show()
    # ---------------------------------------------------------------------------
    # TensorFlow Graph Methods
    # ---------------------------------------------------------------------------
[docs]
    @tf.function
    def build_vcham_tensor(self) -> tf.Tensor:
        """
        Assemble the final Hamiltonian tensor from diagonal and off-diagonal contributions.

        Returns
        -------
        tf.Tensor
            The Hamiltonian eigenvalue tensor.
        """
        diag_vals = self._build_diagonal_potentials()  # shape: [n_states, n_disp]

        # Build initial diagonal tensor with eigenvalue assembly.
        vcham_tensor = tf.linalg.diag(tf.transpose(diag_vals))

        off_indices, off_values = self._collect_off_diagonal_contributions(
            self.lambda_idx, self.lambda_param
        )
        off_idx_jt, off_vals_jt = self._collect_off_diagonal_contributions(
            self.jt_off_idx, self.jt_off_param
        )

        if off_indices:
            all_idx = tf.concat(off_indices, axis=0)
            all_vals = tf.concat(off_values, axis=0)
            vcham_tensor = tf.tensor_scatter_nd_add(vcham_tensor, all_idx, all_vals)
        if off_idx_jt:
            all_idx_jt = tf.concat(off_idx_jt, axis=0)
            all_vals_jt = tf.concat(off_vals_jt, axis=0)
            vcham_tensor = tf.tensor_scatter_nd_add(vcham_tensor, all_idx_jt, all_vals_jt)

        eigenvals = tf.linalg.eigvalsh(vcham_tensor)
        return tf.transpose(eigenvals)
[docs]
    @tf.function
    def train_step(self) -> tf.Tensor:
        """
        A single training step: compute loss, gradients, and update parameters.

        Returns
        -------
        tf.Tensor
            The computed loss value.
        """
        with tf.GradientTape() as tape:
            loss_val = self._cost_function()
        grads = tape.gradient(loss_val, self.optimize_params)
        grads, _ = tf.clip_by_global_norm(grads, 1.0)
        self.optimizer.apply_gradients(zip(grads, self.optimize_params))
        return loss_val
    def _build_diagonal_potentials(self) -> tf.Tensor:
        """
        Assemble diagonal potential energy contributions for each electronic state.

        Returns
        -------
        tf.Tensor
            A tensor stacking the potential for each state.
        """
        n_disp = tf.shape(self.displacement_vector)[0]
        n_st = self.nstates

        diag_per_state = []
        idx_offset = 0
        jt_idx_offset = 0  # Separate offset for JT states

        for s in range(n_st):
            ftype = self.diab_functions[s]
            if s in self.non_jt_states:
                # Use the non-JT parameters (self.funct_param)
                if ftype in diabfunct.potential_functions:
                    n_vars = diabfunct.n_var.get(ftype, 0)
                    if n_vars > 0 and self.funct_param is not None:
                        chunk = self.funct_param[idx_offset : idx_offset + n_vars]
                        idx_offset += n_vars
                    else:
                        chunk = None
                    pot_term = self._assemble_on_diag_term(s, ftype, chunk)
                else:
                    pot_term = tf.zeros([n_disp], dtype=tf.float32)
            elif s in self.active_jt_states:
                # Use the JT parameters (self.jt_on_param); the active member of a JT
                # pair reads the current chunk, and the offset only advances after the
                # inactive partner below, so both states of a pair share parameters.
                if ftype in diabfunct.potential_functions:
                    n_vars = diabfunct.n_var.get(ftype, 0)
                    if n_vars > 0 and self.jt_on_param is not None:
                        chunk = self.jt_on_param[jt_idx_offset : jt_idx_offset + n_vars]
                    else:
                        chunk = None
                    pot_term = self._assemble_on_diag_term(s, ftype, chunk)
                else:
                    pot_term = tf.zeros([n_disp], dtype=tf.float32)
            elif s in self.inactive_jt_states:
                # Use the JT parameters (self.jt_on_param)
                if ftype in diabfunct.potential_functions:
                    n_vars = diabfunct.n_var.get(ftype, 0)
                    if n_vars > 0 and self.jt_on_param is not None:
                        chunk = self.jt_on_param[jt_idx_offset : jt_idx_offset + n_vars]
                        jt_idx_offset += n_vars
                    else:
                        chunk = None
                    pot_term = self._assemble_on_diag_term(s, ftype, chunk)
                else:
                    pot_term = tf.zeros([n_disp], dtype=tf.float32)
            else:
                pot_term = tf.zeros([n_disp], dtype=tf.float32)

            diag_per_state.append(pot_term)

        return tf.stack(diag_per_state, axis=0)

    def _assemble_on_diag_term(
        self, state_idx: int, func_name: str, param_var: Optional[tf.Tensor]
    ) -> tf.Tensor:
        """
        Assemble a single on-diagonal potential term for a given state.

        Parameters
        ----------
        state_idx : int
            The index of the state.
        func_name : str
            The name of the diabatic potential function.
        param_var : tf.Tensor, optional
            The parameter chunk for the potential function.

        Returns
        -------
        tf.Tensor
            The assembled on-diagonal potential term.
        """
        disp = self.displacement_vector
        e0_shift = self.e0_shifts[state_idx]
        pot_fn = diabfunct.potential_functions[func_name]

        if diabfunct.kappa_compatible[func_name]:
            base_term = pot_fn(disp, self.omega, param_var)
            if (
                self.sym_mode == self.total_sym_irrep
                and self.kappa_param is not None
                and state_idx in self.kappa_idx
            ):
                k_idx = self.kappa_idx.index(state_idx)
                kappa_val = self.kappa_param[k_idx]
                base_term += couplingfunct.linear_coupling(disp, kappa_val)
            elif self.inactive_mode:
                logger.debug("Inactive mode on state %d, using JT off-diagonal parameters.", state_idx)
                k_idx = self.active_jt_states.index(state_idx)
                if self.sign_k_params == 1:
                    kappa_val = self.jt_off_param[k_idx]
                else:
                    kappa_val = self.jt_off_param[k_idx - 1]
                kappa_val *= self.sign_k_params
                base_term += couplingfunct.linear_coupling(disp, kappa_val)
                self.sign_k_params *= -1
        else:
            base_term = pot_fn(disp, param_var)

        return base_term + e0_shift

    @tf.function
    def _collect_off_diagonal_contributions(
        self,
        idx_list: List[List[int]],
        param_var: Optional[tf.Variable],
    ) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
        """
        Collect off-diagonal contributions for given indices and parameter variable.

        Parameters
        ----------
        idx_list : list of list of int
            List of index pairs where off-diagonal contributions are applied.
        param_var : tf.Variable, optional
            Parameter variable associated with these contributions.

        Returns
        -------
        tuple
            Two lists: one for indices and one for corresponding values.
        """
        if not idx_list or param_var is None:
            return [], []

        n_disp = tf.shape(self.displacement_vector)[0]
        off_diag_indices: List[tf.Tensor] = []
        off_diag_values: List[tf.Tensor] = []

        if self.coupling_order not in couplingfunct.coupling_funct:
            raise NotImplementedError(f"Coupling '{self.coupling_order}' not implemented.")
        coupling_fn = couplingfunct.coupling_funct[self.coupling_order]

        for i, (st1, st2) in enumerate(idx_list):
            pval = param_var[i] if len(param_var.shape) > 0 else param_var
            update_vals = coupling_fn(self.displacement_vector, pval)

            disp_indices = tf.range(n_disp, dtype=tf.int32)
            st1_indices = tf.fill([n_disp], st1)
            st2_indices = tf.fill([n_disp], st2)
            indices_1 = tf.stack([disp_indices, st1_indices, st2_indices], axis=1)
            indices_2 = tf.stack([disp_indices, st2_indices, st1_indices], axis=1)

            off_diag_indices.extend([indices_1, indices_2])
            off_diag_values.extend([update_vals, update_vals])

        return off_diag_indices, off_diag_values

    @tf.function
    def _cost_function(self) -> tf.Tensor:
        """
        Compute the cost function as the mean loss between the reference and model data.

        Returns
        -------
        tf.Tensor
            The scalar loss value.
        """
        final_tensor = self.build_vcham_tensor()
        loss_val = self.loss_fn(self.data_db, final_tensor)
        return tf.reduce_mean(loss_val)

    # ---------------------------------------------------------------------------
    # Jahn-Teller Handling
    # ---------------------------------------------------------------------------
    def _initialize_diab_fn_variables(self) -> None:
        """
        Prepare diabatic function variables. If JT effects are present for the
        current mode, prepare JT parameters; otherwise, prepare non-JT parameters.
        """
        self.inactive_mode = False
        if not self.JT_effects or not any(
            effect.get("mode") == self.normal_mode for effect in self.JT_effects
        ):
            self._prepare_non_jt_param()
        else:
            self._prepare_jt_param()

    def _prepare_jt_param(self) -> None:
        """
        Prepare JT parameters for modes with Jahn-Teller effects.

        If the current mode is inactive, retrieve parameters from the source active
        mode. Otherwise, process active JT effects normally.
        """
        current_mode_effects = [
            effect for effect in self.JT_effects if effect.get("mode") == self.normal_mode
        ]
        inactive_effects = [
            effect for effect in current_mode_effects if effect.get("active", True) is False
        ]
        active_effects = [effect for effect in current_mode_effects if effect.get("active", True)]

        jt_state_pairs: List[Tuple[int, int]] = []

        if inactive_effects:
            source_mode = inactive_effects[0].get("source")
            if source_mode is None:
                raise ValueError("Inactive JT effect must include a 'source' key.")

            for effect in inactive_effects:
                pairs = effect.get("state_pairs", [])
                jt_state_pairs.extend(pairs)

            total_jt_states = set()
            for pair in jt_state_pairs:
                if isinstance(pair, (list, tuple)) and len(pair) == 2:
                    total_jt_states.update(pair)
                else:
                    raise ValueError(
                        "Each state pair in JT_effects must be a list or tuple of two integers."
                    )
            total_jt_states = list(total_jt_states)
            logger.info("Total JT states from inactive effect: %s", total_jt_states)

            if hasattr(self.VCSystem, "JT_params") and source_mode in self.VCSystem.JT_params:
                source_jt = self.VCSystem.JT_params[source_mode]
                self.jt_on_param = tf.Variable(
                    source_jt["on"], dtype=tf.float32, name="jt_on_param_inactive"
                )
                self.jt_off_param = tf.Variable(
                    source_jt["off"], dtype=tf.float32, name="jt_off_param_inactive"
                )
                logger.info("Using JT parameters from source mode %s.", source_mode)
                self.sign_k_params = 1
                self.active_jt_states = total_jt_states
                all_states = set(range(self.nstates))
                self.inactive_mode = True
                self.non_jt_states = list(all_states - set(total_jt_states))
                for s in range(self.nstates):
                    if s in self.non_jt_states:
                        self.summary_output[s] = ""
                    else:
                        self.summary_output[s] = "JT"
                return
            else:
                logger.warning(
                    "No JT parameters available from source mode %s. Using default guesses.",
                    source_mode,
                )
                # Proceed as if active (or define a default behavior below)

        # Process active JT effects normally.
        for effect in active_effects:
            pairs = effect.get("state_pairs", [])
            jt_state_pairs.extend(pairs)

        active_jt_states = set()
        inactive_jt_states = set()
        total_jt_states = set()
        for pair in jt_state_pairs:
            if isinstance(pair, (list, tuple)) and len(pair) == 2:
                active_jt_states.add(pair[0])
                inactive_jt_states.add(pair[1])
                total_jt_states.update(pair)
            else:
                raise ValueError(
                    "Each state pair in JT_effects must be a list or tuple of two integers."
                )

        self.active_jt_states = list(active_jt_states)
        self.inactive_jt_states = list(inactive_jt_states)
        total_jt_states = list(total_jt_states)
        logger.info("Active JT states: %s", self.active_jt_states)

        all_states = set(range(self.nstates))
        self.non_jt_states = list(all_states - set(total_jt_states))

        diab_non_jt_diab_fn = [self.diab_functions[f] for f in self.non_jt_states]
        diab_jt_diab_fn = [self.diab_functions[f] for f in self.active_jt_states]
        self.n_var_list = [diabfunct.n_var[f] for f in diab_non_jt_diab_fn]
        self.n_var_jt = [diabfunct.n_var[f] for f in diab_jt_diab_fn]

        if self.guess_params is None:
            non_jt_guess = [diabfunct.initial_guesses[f] for f in diab_non_jt_diab_fn]
            flat_guess = np.array(non_jt_guess).flatten().tolist()
            self.funct_param = tf.Variable(flat_guess, dtype=tf.float32, name="funct_param")

            jt_guess = [diabfunct.initial_guesses[f] for f in diab_jt_diab_fn]
            flat_guess = np.array(jt_guess).flatten().tolist()
            self.jt_on_param = tf.Variable(flat_guess, dtype=tf.float32, name="jt_on_param")
            logger.info("JT on guess: %s", jt_guess)

        # Set summary output for each state.
        for s in range(self.nstates):
            if s in self.non_jt_states:
                self.summary_output[s] = ""
            else:
                self.summary_output[s] = "JT"

    def _prepare_non_jt_param(self) -> None:
        """
        Prepare parameters when no JT effects are present.
        """
        self.non_jt_states = list(range(self.nstates))
        if self.guess_params is None:
            self.guess_params = [diabfunct.initial_guesses[f] for f in self.diab_functions]
            self.n_var_list = [diabfunct.n_var[f] for f in self.diab_functions]
            logger.info("n_var_list (non-JT): %s", self.n_var_list)

        if self.guess_params is not None:
            flat_guess = np.array(self.guess_params).flatten().tolist()
            self.funct_param = tf.Variable(flat_guess, dtype=tf.float32, name="funct_param")
        else:
            self.funct_param = None

    def _save_jt(self) -> None:
        """
        Save the optimized JT parameters back to the VCSystem.
        """
        jt_values: Dict[str, Any] = {"mode": self.normal_mode, "params": {"on": None, "off": None}}
        if getattr(self, "jt_on_param", None) is not None:
            jt_values["params"]["on"] = self.jt_on_param.numpy()
        if getattr(self, "jt_off_param", None) is not None:
            jt_values["params"]["off"] = self.jt_off_param.numpy()
        self.VCSystem.append_JT_param(jt_values)
        logger.info("JT parameters saved for mode %s: %s", self.normal_mode, jt_values)

    # ---------------------------------------------------------------------------
    # Core Initialization
    # ---------------------------------------------------------------------------
    def _initialize_object_params(self) -> None:
        """
        Extract parameters from the VCSystem and store them as TensorFlow tensors.
        """
        sys_ = self.VCSystem
        nm = self.normal_mode

        self.nstates: int = sys_.number_states
        self.displacement_vector = tf.convert_to_tensor(sys_.displacement_vector[nm], dtype=tf.float32)
        self.data_db = tf.convert_to_tensor(sys_.database_abinitio[nm], dtype=tf.float32)
        self.diab_functions = [f.lower() for f in sys_.diab_funct[nm]]
        self._validate_diabatic_functions()

        self.coupling_with_gs = sys_.coupling_with_gs
        self.omega = sys_.vib_freq[nm]
        self.e0_shifts = tf.convert_to_tensor(sys_.energy_shift, dtype=tf.float32)
        self.symmetry_mask = tf.convert_to_tensor(sys_.symmetry_matrix[nm], dtype=tf.float32)
        self.symmetry_point_group = sys_.symmetry_point_group.upper()
        self.sym_mode = sys_.symmetry_modes[nm].upper()
        self.JT_effects = sys_.JT_effects
        logger.info("JT effects: %s", self.JT_effects)
        self.total_sym_irrep = sys_.totally_sym_irrep

        coupling_order = getattr(sys_, "vc_type", "linear").lower()
        if coupling_order not in couplingfunct.COUPLING_TYPES:
            raise NotImplementedError(f"Coupling function '{coupling_order}' is not implemented.")
        self.coupling_order = coupling_order

        self.idx_dict = sys_.idx_dict

    def _validate_constructor_input(self, normal_mode: int, VCSystem: Any) -> None:
        """
        Validate critical constructor parameters.

        Parameters
        ----------
        normal_mode : int
            The mode index.
        VCSystem : object
            The system instance.

        Raises
        ------
        ValueError
            If inputs are invalid.
        """
        if not isinstance(normal_mode, int) or normal_mode < 0:
            raise ValueError("Normal mode must be a non-negative integer.")
        if VCSystem is None:
            raise ValueError("A valid VCSystem must be provided.")
        required_attrs = [
            "displacement_vector",
            "database_abinitio",
            "diab_funct",
            "vib_freq",
            "symmetry_matrix",
            "symmetry_point_group",
        ]
        missing = [attr for attr in required_attrs if not hasattr(VCSystem, attr)]
        if missing:
            raise ValueError(f"VCSystem is missing required attributes: {missing}")

    def _validate_diabatic_functions(self) -> None:
        """
        Ensure every specified diabatic function is implemented in diabfunct.
        """
        valid_functions = set(diabfunct.potential_functions.keys())
        for func in self.diab_functions:
            if func not in valid_functions:
                raise ValueError(f"Unsupported function type: {func}")

    def _initialize_internal_variables(self) -> None:
        """
        Initialize internal bookkeeping variables.
        """
        self.summary_output = ["" for _ in range(self.nstates)]
        self.kappa_idx: List[int] = []
        self.lambda_idx: List[List[int]] = []
        self.ntotal_param = 0
        self.funct_param = None
        self.kappa_param = None
        self.lambda_param = None
        self.jt_on_param = None
        self.jt_off_param = None
        self.n_var_list: List[Any] = []
        self.optimize_params: List[Any] = []

    def _process_guess(self, guess: Any, n_pairs: int, name: str = "param") -> tf.Variable:
        """
        Convert a guess value into a tf.Variable of the correct length.

        Parameters
        ----------
        guess : float or list
            The initial guess (scalar or list).
        n_pairs : int
            The number of parameters expected.
        name : str, optional
            The name for the TensorFlow variable.

        Returns
        -------
        tf.Variable
            The created variable.
        """
        if isinstance(guess, (float, int)):
            arr_guess = [guess] * n_pairs
        elif isinstance(guess, list):
            if len(guess) == n_pairs:
                arr_guess = guess
            elif len(guess) < n_pairs:
                arr_guess = (guess * (n_pairs // len(guess) + 1))[:n_pairs]
            else:
                raise ValueError(f"Guess must be a scalar or a list of at most {n_pairs} values.")
        else:
            raise ValueError("Guess must be a scalar or a list.")
        return tf.Variable(arr_guess, dtype=tf.float32, name=name)