import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import diabfunct
import couplingfunct
import symm_vcham
import time
from typing import List, Optional, Tuple, Any, Dict
from dataclasses import dataclass, field
from logging_config import get_logger
logger = get_logger(__name__)
class LVCHam:
"""
Builder and optimizer for an LVC Hamiltonian.
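
Examples
--------
A typical call sequence, assuming ``system`` is an already configured
``VCSystem`` instance (a sketch, not a runnable doctest):

>>> ham = LVCHam(normal_mode=0, VCSystem=system, nepochs=2000)
>>> ham.initialize_params(lambda_guess=0.1, jt_guess=0.01, kappa_guess=0.1)
>>> ham.initialize_loss_function("huber", delta=1.0)
>>> ham.optimize()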
"""
def __init__(
self,
normal_mode: int,
VCSystem: Any,
funct_guess: Optional[Any] = None,
nepochs: int = 2000,
) -> None:
"""
Initialize LVCHam with references to the system and its parameters.
Parameters
----------
normal_mode : int
The vibrational mode to process.
VCSystem : object
The parent system containing parameters and data.
funct_guess : optional
Initial guess parameters for the diabatic functions.
nepochs : int, optional
Number of optimization epochs.
"""
self._validate_constructor_input(normal_mode, VCSystem)
self.normal_mode: int = normal_mode
self.VCSystem: Any = VCSystem
self.guess_params: Optional[Any] = funct_guess
self.nepochs: int = nepochs
self.optimizer = tf.optimizers.Adam(learning_rate=0.001)
logger.info("\n\n------------- INIT LVC HAMILTONIAN BUILDER -------------")
logger.info("normal_mode: %s", normal_mode)
# Setup system-dependent attributes
self._initialize_object_params() # Extract from VCSystem
self._initialize_internal_variables() # Prepare bookkeeping variables
# ---------------------------------------------------------------------------
# Off-diagonal Parameter Generation
# ---------------------------------------------------------------------------
def _create_off_diag_coupling_param(
self, lambda_guess: float = 0.01, jt_guess: float = 0.01
) -> Tuple[List[List[int]], List[List[int]]]:
"""
Generate off-diagonal coupling parameters for the Hamiltonian.
Returns
-------
tuple
A tuple containing two lists:
- lambda indices (upper-triangular entries with value 1)
- JT off-diagonal indices (upper-triangular entries with value 2)
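
Examples
--------
A minimal sketch of the mask filtering performed below, using a
hypothetical 3x3 symmetry mask (1 marks lambda coupling, 2 marks JT
coupling):

>>> import tensorflow as tf
>>> mask = tf.constant([[0., 1., 0.],
...                     [1., 0., 2.],
...                     [0., 2., 0.]])
>>> upper = tf.linalg.band_part(mask, 0, -1) - tf.linalg.diag(tf.linalg.diag_part(mask))
>>> tf.where(tf.equal(upper, 1)).numpy().tolist()
[[0, 1]]
>>> tf.where(tf.equal(upper, 2)).numpy().tolist()
[[1, 2]]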
"""
# Isolate the upper-triangular portion (excluding the diagonal)
upper_triangular = tf.linalg.band_part(self.symmetry_mask, 0, -1) - tf.linalg.diag(
tf.linalg.diag_part(self.symmetry_mask)
)
# Identify indices where off-diagonal lambda contributions are present
lambda_idx_tensor = tf.where(tf.equal(upper_triangular, 1))
self.lambda_idx: List[List[int]] = [[int(i[0]), int(i[1])] for i in lambda_idx_tensor.numpy()]
if self.lambda_idx:
self.lambda_param = self._process_guess(lambda_guess, len(self.lambda_idx), name="lambda_param")
logger.info("Lambda nonzero pairs: %s", self.lambda_idx)
else:
logger.info("No off-diagonal lambda needed for this mode.")
self.lambda_param = None
# Process JT off-diagonal indices (symmetry_mask == 2)
self.create_jt_off_diag(jt_guess=jt_guess)
self._initialize_diab_fn_variables()
return self.lambda_idx, self.jt_off_idx
def create_jt_off_diag(self, jt_guess: float = 0.01) -> None:
"""
Identify JT-type off-diagonal terms (symmetry_mask == 2) and initialize JT parameters.
"""
upper_triangular = tf.linalg.band_part(self.symmetry_mask, 0, -1) - tf.linalg.diag(
tf.linalg.diag_part(self.symmetry_mask)
)
jt_off_tensor = tf.where(tf.equal(upper_triangular, 2))
self.jt_off_idx: List[List[int]] = [[int(i[0]), int(i[1])] for i in jt_off_tensor.numpy()]
logger.info("JT off-diagonal pairs: %s", self.jt_off_idx)
if self.jt_off_idx:
self.jt_off_param = self._process_guess(jt_guess, len(self.jt_off_idx), name="jt_off_param")
else:
self.jt_off_param = None
# ---------------------------------------------------------------------------
# On-diagonal Parameter Generation
# ---------------------------------------------------------------------------
def _create_on_diag_coupling_param(self, kappa_guess: float = 0.1) -> Tuple[List[int], List[int]]:
"""
Generate on-diagonal coupling (kappa) parameters for the Hamiltonian.
Returns
-------
tuple
A tuple containing:
- A list of indices (states) for which kappa is active.
- A list of JT on-diagonal indices (states with symmetry mask value 2).
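
Examples
--------
A minimal sketch of the diagonal selection, using a hypothetical mask
whose diagonal marks a kappa-active state with 1 and a JT state with 2:

>>> import tensorflow as tf
>>> diag = tf.linalg.diag_part(tf.constant([[1., 0., 0.],
...                                         [0., 2., 0.],
...                                         [0., 0., 0.]]))
>>> tf.where(tf.equal(diag, 1)).numpy().tolist()
[[0]]
>>> [int(i) for i in tf.where(tf.equal(diag, 2)).numpy()]
[1]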
"""
diag_sym_mask = tf.linalg.diag_part(self.symmetry_mask)
kappa_candidates = tf.where(tf.equal(diag_sym_mask, 1))
self.kappa_idx: List[int] = []
for idx in kappa_candidates.numpy():
st = int(idx[0])
if diabfunct.kappa_compatible[self.diab_functions[st]]:
self.kappa_idx.append(st)
if self.kappa_idx:
self.kappa_param = self._process_guess(kappa_guess, len(self.kappa_idx), name="kappa_param")
logger.info("Non-zero kappa states: %s", self.kappa_idx)
else:
logger.info("No on-diagonal kappa parameters needed for this mode.")
self.kappa_param = None
# Mark summary output for on-diagonal states
for s in self.kappa_idx:
self.summary_output[s] = "kappa"
jt_on_tensor = tf.where(tf.equal(diag_sym_mask, 2))
self.jt_on_idx: List[int] = [int(i) for i in jt_on_tensor.numpy()]
return self.kappa_idx, self.jt_on_idx
# ---------------------------------------------------------------------------
# Parameter Initialization
# ---------------------------------------------------------------------------
def initialize_params(
self, lambda_guess: float = 0.1, jt_guess: float = 0.01, kappa_guess: float = 0.1
) -> None:
"""
Create the on- and off-diagonal coupling parameters and collect all TF
Variables that will be optimized into ``self.optimize_params``.
"""
idx_on_diag = self._create_on_diag_coupling_param(kappa_guess=kappa_guess)
idx_off_diag = self._create_off_diag_coupling_param(lambda_guess=lambda_guess, jt_guess=jt_guess)
# Clear any previous optimization parameter list and count
self.optimize_params = []
self.ntotal_param = 0
# Collect parameters from various sources if they exist.
if self.funct_param is not None and self.funct_param.shape[0] > 0:
logger.debug("funct_param shape: %d", self.funct_param.shape[0])
self.optimize_params.append(self.funct_param)
self.ntotal_param += int(np.prod(self.funct_param.shape))
if self.lambda_param is not None:
self.optimize_params.append(self.lambda_param)
self.ntotal_param += int(np.prod(self.lambda_param.shape))
if self.kappa_param is not None:
self.optimize_params.append(self.kappa_param)
self.ntotal_param += int(np.prod(self.kappa_param.shape))
if hasattr(self, "jt_on_param") and self.jt_on_param is not None:
self.optimize_params.append(self.jt_on_param)
self.ntotal_param += int(np.prod(self.jt_on_param.shape))
if self.jt_off_param is not None:
self.optimize_params.append(self.jt_off_param)
self.ntotal_param += int(np.prod(self.jt_off_param.shape))
logger.info("Total number of parameters to optimize: %s", self.ntotal_param)
logger.info("Optimize_params: %s", self.optimize_params)
# Update indices in the VCSystem's bookkeeping dictionary.
self.VCSystem.idx_dict["kappa"][self.normal_mode] = idx_on_diag[0]
self.VCSystem.idx_dict["jt_on"][self.normal_mode] = idx_on_diag[1]
self.VCSystem.idx_dict["lambda"][self.normal_mode] = idx_off_diag[0]
self.VCSystem.idx_dict["jt_off"][self.normal_mode] = idx_off_diag[1]
def initialize_loss_function(self, fn: str = "huber", **kwargs) -> None:
"""
Initialize a TensorFlow loss function by name.
Parameters
----------
fn : str, optional
Name of the loss function (default is 'huber').
**kwargs :
Additional keyword arguments for the loss function (forwarded only to the Huber loss).
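
Examples
--------
A sketch of typical calls on an ``LVCHam`` instance ``ham`` (keyword
arguments are forwarded only to the Huber loss):

>>> ham.initialize_loss_function("huber", delta=0.5)
>>> ham.initialize_loss_function("mse")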
"""
loss_fn_map = {
"huber": lambda: tf.keras.losses.Huber(**kwargs),
"mse": lambda: tf.keras.losses.MeanSquaredError(),
"mae": lambda: tf.keras.losses.MeanAbsoluteError(),
"msle": lambda: tf.keras.losses.MeanSquaredLogarithmicError(),
"logcosh": lambda: tf.keras.losses.LogCosh(),
"kld": lambda: tf.keras.losses.KLDivergence(),
"poisson": lambda: tf.keras.losses.Poisson(),
"cosine": lambda: tf.keras.losses.CosineSimilarity(),
"sparse": lambda: tf.keras.losses.SparseCategoricalCrossentropy(),
"binary": lambda: tf.keras.losses.BinaryCrossentropy(),
}
fn_lower = fn.lower()
if fn_lower not in loss_fn_map:
raise NotImplementedError(f"Loss function '{fn_lower}' not implemented.")
self.loss_fn = loss_fn_map[fn_lower]()
# ---------------------------------------------------------------------------
# Optimization and Training Loop
# ---------------------------------------------------------------------------
def optimize(self) -> None:
"""
Run gradient-based optimization for the specified number of epochs or
until early stopping is triggered. If the current mode is marked inactive,
build the final Hamiltonian tensor directly without optimization.
"""
# Check for an inactive mode (e.g., an E⊗e JT effect whose parameters are copied from a source mode)
inactive_mode = False
if self.JT_effects is not None:
for effect in self.JT_effects:
if effect.get("mode") == self.normal_mode and effect.get("active", True) is False:
inactive_mode = True
break
if inactive_mode or self.ntotal_param == 0:
final_tensor = self.build_vcham_tensor()
logger.info("Mode inactive or no parameters to optimize; final tensor built directly.")
self.plot_results(final_tensor.numpy(), loss_history=[])
logger.info("------- No optimization performed -------")
self.VCSystem.summary_output[self.normal_mode] = self.summary_output
self.VCSystem.optimized_params[self.normal_mode] = self.optimize_params
return
t0 = time.perf_counter()
best_loss = float("inf")
patience = 10000
patience_counter = 0
loss_history: List[float] = []
for step in range(self.nepochs):
loss_val = self.train_step()
current_loss = float(loss_val.numpy())
loss_history.append(current_loss)
if step % 100 == 0:
logger.info("Step %d, Loss: %.6f", step, current_loss)
if current_loss < (best_loss - 1e-5):
best_loss = current_loss
patience_counter = 0
else:
patience_counter += 1
if patience_counter >= patience:
logger.info("Early stopping at step %d.", step)
break
tf_time = time.perf_counter() - t0
final_tensor = self.build_vcham_tensor().numpy()
self.plot_results(final_tensor, loss_history)
if not inactive_mode and hasattr(self, "active_jt_states"):
self._save_jt()
logger.info("Summary_output: %s", self.summary_output)
logger.info("Final loss: %.6f", best_loss)
logger.info("Optimization finished in %.2f seconds.", tf_time)
self.VCSystem.summary_output[self.normal_mode] = self.summary_output
self.VCSystem.optimized_params[self.normal_mode] = self.optimize_params
def plot_results(self, final_tensor_np: np.ndarray, loss_history: List[float]) -> None:
"""
Plot the final eigenvalues vs. reference data and the loss evolution.
Parameters
----------
final_tensor_np : np.ndarray
The final Hamiltonian eigenvalues.
loss_history : list of float
The loss evolution during optimization.
"""
disp_np = (
self.displacement_vector.numpy()
if isinstance(self.displacement_vector, tf.Tensor)
else self.displacement_vector
)
data_db_np = (
self.data_db.numpy()
if isinstance(self.data_db, tf.Tensor)
else self.data_db
)
if self.coupling_with_gs:
plt.figure(figsize=(10, 6))
for i in range(final_tensor_np.shape[0]):
plt.plot(disp_np, final_tensor_np[i], label=f"State {i}")
plt.scatter(disp_np, data_db_np[i], s=5)
plt.xlabel(f"$\\text{{Q}}_{{{self.normal_mode}}} - \\text{{irrep: {self.sym_mode}}}$")
plt.ylabel(f"Energy [{self.VCSystem.units}]")
plt.legend(prop={"size": 6})
plt.show()
else:
plt.figure(figsize=(10, 6))
plt.plot(disp_np, final_tensor_np[0], label="Ground State")
plt.scatter(disp_np, data_db_np[0], s=5)
plt.xlabel(f"$\\text{{Q}}_{{{self.normal_mode}}} - \\text{{irrep: {self.sym_mode}}}$")
plt.ylabel(f"Energy [{self.VCSystem.units}]")
plt.legend(prop={"size": 7})
plt.show()
plt.figure(figsize=(10, 6))
for i in range(1, final_tensor_np.shape[0]):
plt.plot(disp_np, final_tensor_np[i], label=f"State {i}")
plt.scatter(disp_np, data_db_np[i], s=5)
plt.xlabel(f"$\\text{{Q}}_{{{self.normal_mode}}} - \\text{{irrep: {self.sym_mode}}}$")
plt.ylabel(f"Energy [{self.VCSystem.units}]")
plt.legend(prop={"size": 7})
plt.show()
if loss_history:
plt.figure(figsize=(5, 3))
plt.plot(loss_history)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.grid(True)
plt.show()
# ---------------------------------------------------------------------------
# TensorFlow Graph Methods
# ---------------------------------------------------------------------------
@tf.function
def build_vcham_tensor(self) -> tf.Tensor:
"""
Assemble the final Hamiltonian tensor from diagonal and off-diagonal contributions.
Returns
-------
tf.Tensor
The Hamiltonian eigenvalue tensor.
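
Examples
--------
A standalone sketch of the assembly pattern used below, with
hypothetical numbers for two states and two displacement points:

>>> import tensorflow as tf
>>> diag_vals = tf.constant([[0.0, 0.1], [1.0, 1.2]])  # [n_states, n_disp]
>>> h = tf.linalg.diag(tf.transpose(diag_vals))  # [n_disp, n_states, n_states]
>>> h = tf.tensor_scatter_nd_add(h, [[0, 0, 1], [0, 1, 0]], [0.05, 0.05])
>>> tf.transpose(tf.linalg.eigvalsh(h)).shape  # back to [n_states, n_disp]
TensorShape([2, 2])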
"""
diag_vals = self._build_diagonal_potentials() # shape: [n_states, n_disp]
# Build initial diagonal tensor with eigenvalue assembly.
vcham_tensor = tf.linalg.diag(tf.transpose(diag_vals))
off_indices, off_values = self._collect_off_diagonal_contributions(self.lambda_idx, self.lambda_param)
off_idx_jt, off_vals_jt = self._collect_off_diagonal_contributions(self.jt_off_idx, self.jt_off_param)
if off_indices:
all_idx = tf.concat(off_indices, axis=0)
all_vals = tf.concat(off_values, axis=0)
vcham_tensor = tf.tensor_scatter_nd_add(vcham_tensor, all_idx, all_vals)
if off_idx_jt:
all_idx_jt = tf.concat(off_idx_jt, axis=0)
all_vals_jt = tf.concat(off_vals_jt, axis=0)
vcham_tensor = tf.tensor_scatter_nd_add(vcham_tensor, all_idx_jt, all_vals_jt)
eigenvals = tf.linalg.eigvalsh(vcham_tensor)
return tf.transpose(eigenvals)
@tf.function
def train_step(self) -> tf.Tensor:
"""
A single training step: compute loss, gradients, and update parameters.
Returns
-------
tf.Tensor
The computed loss value.
"""
with tf.GradientTape() as tape:
loss_val = self._cost_function()
grads = tape.gradient(loss_val, self.optimize_params)
grads, _ = tf.clip_by_global_norm(grads, 1.0)
self.optimizer.apply_gradients(zip(grads, self.optimize_params))
return loss_val
def _build_diagonal_potentials(self) -> tf.Tensor:
"""
Assemble diagonal potential energy contributions for each electronic state.
Returns
-------
tf.Tensor
A tensor stacking the potential for each state.
"""
n_disp = tf.shape(self.displacement_vector)[0]
n_st = self.nstates
diag_per_state = []
idx_offset = 0
jt_idx_offset = 0  # Separate offset for JT parameters; advanced only in the inactive-JT branch below
for s in range(n_st):
ftype = self.diab_functions[s]
if s in self.non_jt_states:
# Use the non-JT parameters (self.funct_param)
if ftype in diabfunct.potential_functions:
n_vars = diabfunct.n_var.get(ftype, 0)
if n_vars > 0 and self.funct_param is not None:
chunk = self.funct_param[idx_offset : idx_offset + n_vars]
idx_offset += n_vars
else:
chunk = None
pot_term = self._assemble_on_diag_term(s, ftype, chunk)
else:
pot_term = tf.zeros([n_disp], dtype=tf.float32)
elif s in self.active_jt_states:
# Use the JT parameters (self.jt_on_param)
if ftype in diabfunct.potential_functions:
n_vars = diabfunct.n_var.get(ftype, 0)
if n_vars > 0 and self.jt_on_param is not None:
chunk = self.jt_on_param[jt_idx_offset : jt_idx_offset + n_vars]
else:
chunk = None
pot_term = self._assemble_on_diag_term(s, ftype, chunk)
else:
pot_term = tf.zeros([n_disp], dtype=tf.float32)
elif s in self.inactive_jt_states:
# Use the JT parameters (self.jt_on_param)
if ftype in diabfunct.potential_functions:
n_vars = diabfunct.n_var.get(ftype, 0)
if n_vars > 0 and self.jt_on_param is not None:
chunk = self.jt_on_param[jt_idx_offset : jt_idx_offset + n_vars]
jt_idx_offset += n_vars
else:
chunk = None
pot_term = self._assemble_on_diag_term(s, ftype, chunk)
else:
pot_term = tf.zeros([n_disp], dtype=tf.float32)
else:
pot_term = tf.zeros([n_disp], dtype=tf.float32)
diag_per_state.append(pot_term)
return tf.stack(diag_per_state, axis=0)
def _assemble_on_diag_term(self, state_idx: int, func_name: str, param_var: Optional[tf.Tensor]) -> tf.Tensor:
"""
Assemble a single on-diagonal potential term for a given state.
Parameters
----------
state_idx : int
The index of the state.
func_name : str
The name of the diabatic potential function.
param_var : tf.Tensor, optional
The parameter chunk for the potential function.
Returns
-------
tf.Tensor
The assembled on-diagonal potential term.
"""
disp = self.displacement_vector
e0_shift = self.e0_shifts[state_idx]
pot_fn = diabfunct.potential_functions[func_name]
if diabfunct.kappa_compatible[func_name]:
base_term = pot_fn(disp, self.omega, param_var)
if (
self.sym_mode == self.total_sym_irrep
and self.kappa_param is not None
and state_idx in self.kappa_idx
):
k_idx = self.kappa_idx.index(state_idx)
kappa_val = self.kappa_param[k_idx]
base_term += couplingfunct.linear_coupling(disp, kappa_val)
elif self.inactive_mode:
logger.debug("Inactive mode on state %d, using JT off-diagonal parameters.", state_idx)
k_idx = self.active_jt_states.index(state_idx)
if self.sign_k_params == 1:
kappa_val = self.jt_off_param[k_idx]
else:
kappa_val = self.jt_off_param[k_idx - 1]
kappa_val *= self.sign_k_params
base_term += couplingfunct.linear_coupling(disp, kappa_val)
self.sign_k_params *= -1
else:
base_term = pot_fn(disp, param_var)
return base_term + e0_shift
@tf.function
def _collect_off_diagonal_contributions(
self,
idx_list: List[List[int]],
param_var: Optional[tf.Variable],
) -> Tuple[List[tf.Tensor], List[tf.Tensor]]:
"""
Collect off-diagonal contributions for given indices and parameter variable.
Parameters
----------
idx_list : list of list of int
List of index pairs where off-diagonal contributions are applied.
param_var : tf.Variable, optional
Parameter variable associated with these contributions.
Returns
-------
tuple
Two lists: one for indices and one for corresponding values.
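
Examples
--------
A minimal sketch of the scatter-index layout built below for one state
pair (st1=0, st2=1) over three displacement points:

>>> import tensorflow as tf
>>> n_disp = 3
>>> disp_idx = tf.range(n_disp, dtype=tf.int32)
>>> tf.stack([disp_idx, tf.fill([n_disp], 0), tf.fill([n_disp], 1)], axis=1).numpy().tolist()
[[0, 0, 1], [1, 0, 1], [2, 0, 1]]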
"""
if not idx_list or param_var is None:
return [], []
n_disp = tf.shape(self.displacement_vector)[0]
off_diag_indices: List[tf.Tensor] = []
off_diag_values: List[tf.Tensor] = []
if self.coupling_order not in couplingfunct.coupling_funct:
raise NotImplementedError(f"Coupling '{self.coupling_order}' not implemented.")
coupling_fn = couplingfunct.coupling_funct[self.coupling_order]
for i, (st1, st2) in enumerate(idx_list):
pval = param_var[i] if len(param_var.shape) > 0 else param_var
update_vals = coupling_fn(self.displacement_vector, pval)
disp_indices = tf.range(n_disp, dtype=tf.int32)
st1_indices = tf.fill([n_disp], st1)
st2_indices = tf.fill([n_disp], st2)
indices_1 = tf.stack([disp_indices, st1_indices, st2_indices], axis=1)
indices_2 = tf.stack([disp_indices, st2_indices, st1_indices], axis=1)
off_diag_indices.extend([indices_1, indices_2])
off_diag_values.extend([update_vals, update_vals])
return off_diag_indices, off_diag_values
@tf.function
def _cost_function(self) -> tf.Tensor:
"""
Compute the cost function as the mean loss between the reference and model data.
Returns
-------
tf.Tensor
The scalar loss value.
"""
final_tensor = self.build_vcham_tensor()
loss_val = self.loss_fn(self.data_db, final_tensor)
return tf.reduce_mean(loss_val)
# ---------------------------------------------------------------------------
# Jahn-Teller Handling
# ---------------------------------------------------------------------------
def _initialize_diab_fn_variables(self) -> None:
"""
Prepare diabatic function variables.
If JT effects are present for the current mode, prepare JT parameters;
otherwise, prepare non-JT parameters.
"""
self.inactive_mode = False
if not self.JT_effects or not any(effect.get("mode") == self.normal_mode for effect in self.JT_effects):
self._prepare_non_jt_param()
else:
self._prepare_jt_param()
def _prepare_jt_param(self) -> None:
"""
Prepare JT parameters for modes with Jahn-Teller effects.
If the current mode is inactive, retrieve parameters from the source active mode.
Otherwise, process active JT effects normally.
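
Examples
--------
A hedged sketch of the ``JT_effects`` layout this method expects; the
keys are inferred from the lookups below and the exact schema may
differ:

>>> JT_effects = [
...     {"mode": 7, "active": True, "state_pairs": [[1, 2]]},
...     {"mode": 8, "active": False, "source": 7, "state_pairs": [[1, 2]]},
... ]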
"""
current_mode_effects = [effect for effect in self.JT_effects if effect.get("mode") == self.normal_mode]
inactive_effects = [effect for effect in current_mode_effects if effect.get("active", True) is False]
active_effects = [effect for effect in current_mode_effects if effect.get("active", True)]
jt_state_pairs: List[Tuple[int, int]] = []
if inactive_effects:
source_mode = inactive_effects[0].get("source")
if source_mode is None:
raise ValueError("Inactive JT effect must include a 'source' key.")
for effect in inactive_effects:
pairs = effect.get("state_pairs", [])
jt_state_pairs.extend(pairs)
total_jt_states = set()
for pair in jt_state_pairs:
if isinstance(pair, (list, tuple)) and len(pair) == 2:
total_jt_states.update(pair)
else:
raise ValueError("Each state pair in JT_effects must be a list or tuple of two integers.")
total_jt_states = list(total_jt_states)
logger.info("Total JT states from inactive effect: %s", total_jt_states)
if hasattr(self.VCSystem, "JT_params") and source_mode in self.VCSystem.JT_params:
source_jt = self.VCSystem.JT_params[source_mode]
self.jt_on_param = tf.Variable(source_jt["on"], dtype=tf.float32, name="jt_on_param_inactive")
self.jt_off_param = tf.Variable(source_jt["off"], dtype=tf.float32, name="jt_off_param_inactive")
logger.info("Using JT parameters from source mode %s.", source_mode)
self.sign_k_params = 1
self.active_jt_states = total_jt_states
all_states = set(range(self.nstates))
self.inactive_mode = True
self.non_jt_states = list(all_states - set(total_jt_states))
for s in range(self.nstates):
if s in self.non_jt_states:
self.summary_output[s] = ""
else:
self.summary_output[s] = "JT"
return
else:
logger.warning("No JT parameters available from source mode %s. Using default guesses.", source_mode)
# Proceed as if active (or define a default behavior below)
# Process active JT effects normally.
for effect in active_effects:
pairs = effect.get("state_pairs", [])
jt_state_pairs.extend(pairs)
active_jt_states = set()
inactive_jt_states = set()
total_jt_states = set()
for pair in jt_state_pairs:
if isinstance(pair, (list, tuple)) and len(pair) == 2:
active_jt_states.add(pair[0])
inactive_jt_states.add(pair[1])
total_jt_states.update(pair)
else:
raise ValueError("Each state pair in JT_effects must be a list or tuple of two integers.")
self.active_jt_states = list(active_jt_states)
self.inactive_jt_states = list(inactive_jt_states)
total_jt_states = list(total_jt_states)
logger.info("Active JT states: %s", self.active_jt_states)
all_states = set(range(self.nstates))
self.non_jt_states = list(all_states - set(total_jt_states))
diab_non_jt_diab_fn = [self.diab_functions[f] for f in self.non_jt_states]
diab_jt_diab_fn = [self.diab_functions[f] for f in self.active_jt_states]
self.n_var_list = [diabfunct.n_var[f] for f in diab_non_jt_diab_fn]
self.n_var_jt = [diabfunct.n_var[f] for f in diab_jt_diab_fn]
if self.guess_params is None:
non_jt_guess = [diabfunct.initial_guesses[f] for f in diab_non_jt_diab_fn]
flat_guess = np.array(non_jt_guess).flatten().tolist()
self.funct_param = tf.Variable(flat_guess, dtype=tf.float32, name="funct_param")
jt_guess = [diabfunct.initial_guesses[f] for f in diab_jt_diab_fn]
flat_guess = np.array(jt_guess).flatten().tolist()
self.jt_on_param = tf.Variable(flat_guess, dtype=tf.float32, name="jt_on_param")
logger.info("JT on guess: %s", jt_guess)
# Set summary output for each state.
for s in range(self.nstates):
if s in self.non_jt_states:
self.summary_output[s] = ""
else:
self.summary_output[s] = "JT"
def _prepare_non_jt_param(self) -> None:
"""
Prepare parameters when no JT effects are present.
"""
self.non_jt_states = list(range(self.nstates))
if self.guess_params is None:
self.guess_params = [diabfunct.initial_guesses[f] for f in self.diab_functions]
self.n_var_list = [diabfunct.n_var[f] for f in self.diab_functions]
logger.info("n_var_list (non-JT): %s", self.n_var_list)
if self.guess_params is not None:
flat_guess = np.array(self.guess_params).flatten().tolist()
self.funct_param = tf.Variable(flat_guess, dtype=tf.float32, name="funct_param")
else:
self.funct_param = None
def _save_jt(self) -> None:
"""
Save the optimized JT parameters back to the VCSystem.
"""
jt_values: Dict[str, Any] = {"mode": self.normal_mode, "params": {"on": None, "off": None}}
if getattr(self, "jt_on_param", None) is not None:
jt_values["params"]["on"] = self.jt_on_param.numpy()
if getattr(self, "jt_off_param", None) is not None:
jt_values["params"]["off"] = self.jt_off_param.numpy()
self.VCSystem.append_JT_param(jt_values)
logger.info("JT parameters saved for mode %s: %s", self.normal_mode, jt_values)
# ---------------------------------------------------------------------------
# Core Initialization
# ---------------------------------------------------------------------------
def _initialize_object_params(self) -> None:
"""
Extract parameters from the VCSystem and store them as TensorFlow tensors.
"""
sys_ = self.VCSystem
nm = self.normal_mode
self.nstates: int = sys_.number_states
self.displacement_vector = tf.convert_to_tensor(sys_.displacement_vector[nm], dtype=tf.float32)
self.data_db = tf.convert_to_tensor(sys_.database_abinitio[nm], dtype=tf.float32)
self.diab_functions = [f.lower() for f in sys_.diab_funct[nm]]
self._validate_diabatic_functions()
self.coupling_with_gs = sys_.coupling_with_gs
self.omega = sys_.vib_freq[nm]
self.e0_shifts = tf.convert_to_tensor(sys_.energy_shift, dtype=tf.float32)
self.symmetry_mask = tf.convert_to_tensor(sys_.symmetry_matrix[nm], dtype=tf.float32)
self.symmetry_point_group = sys_.symmetry_point_group.upper()
self.sym_mode = sys_.symmetry_modes[nm].upper()
self.JT_effects = sys_.JT_effects
logger.info("JT effects: %s", self.JT_effects)
self.total_sym_irrep = sys_.totally_sym_irrep
coupling_order = getattr(sys_, "vc_type", "linear").lower()
if coupling_order not in couplingfunct.COUPLING_TYPES:
raise NotImplementedError(f"Coupling function '{coupling_order}' is not implemented.")
self.coupling_order = coupling_order
self.idx_dict = sys_.idx_dict
def _validate_constructor_input(self, normal_mode: int, VCSystem: Any) -> None:
"""
Validate critical constructor parameters.
Parameters
----------
normal_mode : int
The mode index.
VCSystem : object
The system instance.
Raises
------
ValueError
If inputs are invalid.
"""
if not isinstance(normal_mode, int) or normal_mode < 0:
raise ValueError("Normal mode must be a non-negative integer.")
if VCSystem is None:
raise ValueError("A valid VCSystem must be provided.")
required_attrs = [
"displacement_vector", "database_abinitio", "diab_funct",
"vib_freq", "symmetry_matrix", "symmetry_point_group"
]
missing = [attr for attr in required_attrs if not hasattr(VCSystem, attr)]
if missing:
raise ValueError(f"VCSystem is missing required attributes: {missing}")
def _validate_diabatic_functions(self) -> None:
"""
Ensure every specified diabatic function is implemented in diabfunct.
"""
valid_functions = set(diabfunct.potential_functions.keys())
for func in self.diab_functions:
if func not in valid_functions:
raise ValueError(f"Unsupported function type: {func}")
def _initialize_internal_variables(self) -> None:
"""
Initialize internal bookkeeping variables.
"""
self.summary_output = ["" for _ in range(self.nstates)]
self.kappa_idx: List[int] = []
self.lambda_idx: List[List[int]] = []
self.ntotal_param = 0
self.funct_param = None
self.kappa_param = None
self.lambda_param = None
self.jt_on_param = None
self.jt_off_param = None
self.n_var_list: List[Any] = []
self.optimize_params: List[Any] = []
def _process_guess(self, guess: Any, n_pairs: int, name: str = "param") -> tf.Variable:
"""
Convert a guess value into a tf.Variable of the correct length.
Parameters
----------
guess : float or list
The initial guess (scalar or list).
n_pairs : int
The number of parameters expected.
name : str, optional
The name for the TensorFlow variable.
Returns
-------
tf.Variable
The created variable.
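
Examples
--------
A short list guess is tiled to the required length (a sketch of the
tiling rule applied below):

>>> guess = [0.01, 0.02]
>>> (guess * (3 // len(guess) + 1))[:3]
[0.01, 0.02, 0.01]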
"""
if isinstance(guess, (float, int)):
arr_guess = [guess] * n_pairs
elif isinstance(guess, list):
if len(guess) == n_pairs:
arr_guess = guess
elif len(guess) < n_pairs:
arr_guess = (guess * (n_pairs // len(guess) + 1))[:n_pairs]
else:
raise ValueError(f"Off-diagonal guess must be a scalar or list of length {n_pairs}.")
else:
raise ValueError("Guess must be a scalar or a list.")
return tf.Variable(arr_guess, dtype=tf.float32, name=name)