Source code for dpest.pst

import os
import re
import tempfile

import yaml
import pandas as pd
import pyemu

from dpest.functions import *

def pst(
        cultivar_parameters=None,
        ecotype_parameters=None,
        dataframe_observations=None,
        output_path=None,
        model_comand_line=None,
        noptmax=1000,
        pst_filename='PEST_CONTROL.pst',
        input_output_file_pairs=None
):
    """
    Create and update a PEST control file (PST) for CERES-Wheat model calibration.

    Args:
        cultivar_parameters (dict): Dictionary containing model parameters with their
            values, bounds, and groupings. It should include:
            - 'parameters': Current parameter values for the specified cultivar.
            - 'minima_parameters': Minima values for all parameters.
            - 'maxima_parameters': Maxima values for all parameters.
            - 'parameters_grouped': Grouping of parameters.
        ecotype_parameters (dict): Dictionary containing ecotype parameters, with the
            same structure as `cultivar_parameters`.
        dataframe_observations (pd.DataFrame or list): DataFrame or list of DataFrames
            containing observations to include in the PST file. Each DataFrame must
            include the columns 'variable_name', 'value_measured', and 'group'.
        output_path (str): Directory to save the PST file. Defaults to the current
            working directory if not provided.
        model_comand_line (str): Command line for running the model executable.
        noptmax (int): Maximum number of iterations for the optimization process.
            Default is 1000.
        pst_filename (str): The name of the PST file to create or update. Default is
            'PEST_CONTROL.pst'.
        input_output_file_pairs (list): List of tuples where each tuple contains a
            PEST file (TPL or INS) and its paired model input or output file.

    Returns:
        None: This function creates the PST file at the specified `output_path` with
        the provided name (`pst_filename`). It performs validation on inputs,
        processes observation data, sets up parameters, and writes the resulting
        PST file.

    Raises:
        ValueError: If required arguments are missing or invalid values are encountered.
        FileNotFoundError: If a required file does not exist.
        Exception: For any other unexpected errors.
    """
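    # A minimal sketch (hypothetical coefficient names and values, for
    # illustration only) of the dictionary structure expected by
    # `cultivar_parameters` and `ecotype_parameters`:
    #
    #     {
    #         'parameters':         {'P1V': 5.0,  'P1D': 75.0},
    #         'minima_parameters':  {'P1V': 0.0,  'P1D': 0.0},
    #         'maxima_parameters':  {'P1V': 60.0, 'P1D': 200.0},
    #         'parameters_grouped': {'cultivar': 'P1V, P1D'},
    #     }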
""" # Define default variables yml_pst_file_block = 'PST_FILE' yml_file_observation_groups = 'OBSERVATION_GROUPS_SPECIFICATIONS' try: ## Get the yaml_data # Get the directory of the current script current_dir = os.path.dirname(os.path.abspath(__file__)) # Construct the path to arguments.yml arguments_file = os.path.join(current_dir, 'arguments.yml') # Ensure the YAML file exists if not os.path.isfile(arguments_file): raise FileNotFoundError(f"YAML file not found: {arguments_file}") # Load YAML configuration with open(arguments_file, 'r') as yml_file: yaml_data = yaml.safe_load(yml_file) # Validate inputs if not (cultivar_parameters or ecotype_parameters): raise ValueError( "At least one of `cultivar_parameters` or `ecotype_parameters` must be provided and non-empty.") if cultivar_parameters and not isinstance(cultivar_parameters, dict): raise ValueError("`cultivar_parameters`, if provided, must be a dictionary.") if ecotype_parameters and not isinstance(ecotype_parameters, dict): raise ValueError("`ecotype_parameters`, if provided, must be a dictionary.") # Additional validation for file extensions based on parameters if cultivar_parameters: if not any(pair[1].lower().endswith('.cul') for pair in input_output_file_pairs): raise ValueError( "If `cultivar_parameters` is provided, at least one file in `input_output_file_pairs` must have a '.CUL' extension.") if ecotype_parameters: if not any(pair[1].lower().endswith('.eco') for pair in input_output_file_pairs): raise ValueError( "If `ecotype_parameters` is provided, at least one file in `input_output_file_pairs` must have a '.ECO' extension.") # Validate that at least one file has a '.OUT' extension if not any(pair[1].lower().endswith('.out') for pair in input_output_file_pairs): raise ValueError("At least one file in `input_output_file_pairs` must have a '.OUT' extension.") if dataframe_observations is None: raise ValueError("`dataframe_observations` must be provided.") # Convert single dataframe to list for consistent processing if isinstance(dataframe_observations, pd.DataFrame): dataframe_observations = [dataframe_observations] if not isinstance(dataframe_observations, list) or not all( isinstance(df, pd.DataFrame) for df in dataframe_observations): raise ValueError("`dataframe_observations` must be a DataFrame or a list of DataFrames.") required_columns = {'variable_name', 'value_measured', 'group'} for df in dataframe_observations: if not required_columns.issubset(df.columns): raise ValueError( "Each DataFrame in `dataframe_observations` must contain 'variable_name', 'value_measured', and 'group' columns.") # Get Parameter Group Variables observation_groups = yaml_data[yml_pst_file_block][yml_file_observation_groups] # Merge dictionaries if both are provided, or use the one that exists parameters = { 'parameters': {**(cultivar_parameters.get('parameters', {}) if cultivar_parameters else {}), **(ecotype_parameters.get('parameters', {}) if ecotype_parameters else {})}, 'minima_parameters': {**(cultivar_parameters.get('minima_parameters', {}) if cultivar_parameters else {}), **(ecotype_parameters.get('minima_parameters', {}) if ecotype_parameters else {})}, 'maxima_parameters': {**(cultivar_parameters.get('maxima_parameters', {}) if cultivar_parameters else {}), **(ecotype_parameters.get('maxima_parameters', {}) if ecotype_parameters else {})}, 'parameters_grouped': {**(cultivar_parameters.get('parameters_grouped', {}) if cultivar_parameters else {}), **(ecotype_parameters.get('parameters_grouped', {}) if ecotype_parameters else {})} } # 
        # Extract the full parameter name list from the grouped parameters
        all_params = [
            param
            for group in parameters['parameters_grouped'].values()
            for param in group.replace(' ', '').split(',')
        ]

        # Create a minimal PST object
        pst = pyemu.pst_utils.generic_pst(all_params)

        # Populate the parameter data
        for param in all_params:
            pst.parameter_data.loc[param, 'parval1'] = float(parameters['parameters'][param])
            pst.parameter_data.loc[param, "parlbnd"] = float(parameters['minima_parameters'][param])
            pst.parameter_data.loc[param, "parubnd"] = float(parameters['maxima_parameters'][param])
            pst.parameter_data.loc[param, "pargp"] = next(
                (group for group, params in parameters['parameters_grouped'].items()
                 if param in params.split(', ')),
                None)
            # Add PARTRANS and PARCHGLIM
            pst.parameter_data.loc[param, "partrans"] = "none"       # Set PARTRANS to none
            pst.parameter_data.loc[param, "parchglim"] = "relative"  # Set PARCHGLIM to relative

        # Create parameter groups using values from observation_groups
        pargp_data = []
        for group in parameters['parameters_grouped'].keys():
            pargp_entry = {"pargpnme": group}       # Start with the group name
            pargp_entry.update(observation_groups)  # Update with values from observation_groups
            pargp_data.append(pargp_entry)

        # Convert the parameter groups list to a DataFrame
        pst.parameter_groups = pd.DataFrame(pargp_data)

        # Clear the existing observation data
        pst.observation_data = pst.observation_data.iloc[0:0]

        # Process all observation DataFrames
        for df in dataframe_observations:
            # Validate and clean the observation data
            df['value_measured'] = pd.to_numeric(df['value_measured'], errors='coerce')
            df = df.dropna(subset=['value_measured'])
            for _, row in df.iterrows():
                obsnme = row['variable_name']
                obsval = row['value_measured']
                obgnme = row['group']
                pst.observation_data.loc[obsnme, 'obsnme'] = obsnme
                pst.observation_data.loc[obsnme, 'obsval'] = obsval
                pst.observation_data.loc[obsnme, 'obgnme'] = obgnme
                pst.observation_data.loc[obsnme, 'weight'] = 1.0  # Default weight

        # ~~~~~~~~ Handle input and output files
        if input_output_file_pairs:
            # Validate the file pairs
            if not all(len(pair) == 2 for pair in input_output_file_pairs):
                raise ValueError("Each input_output_file_pair must contain exactly two elements")
            if not all(pair[0].lower().endswith(('.tpl', '.ins')) for pair in input_output_file_pairs):
                raise ValueError("The first element of each pair must be a .tpl or .ins file")

            # Validate file existence
            for pair in input_output_file_pairs:
                validate_file_path(pair[0])  # Validate the PEST file (TPL or INS)
                validate_file_path(pair[1])  # Validate the model file

            # Count TPL and INS files
            def count_file_types(file_pairs):
                tpl_count = sum(1 for pair in file_pairs if pair[0].lower().endswith('.tpl'))
                ins_count = sum(1 for pair in file_pairs if pair[0].lower().endswith('.ins'))
                return tpl_count, ins_count

            # Add quotes to escape spaces in paths
            def escape_spaces(file_pairs):
                return [
                    (f'"{pair[0]}"' if ' ' in pair[0] else pair[0],
                     f'"{pair[1]}"' if ' ' in pair[1] else pair[1])
                    for pair in file_pairs
                ]

            # Escape spaces in paths
            input_output_file_pairs = escape_spaces(input_output_file_pairs)

            # Count TPL and INS files
            tpl_count, ins_count = count_file_types(input_output_file_pairs)

            # Set the input files (TPL files)
            pst.model_input_data = pd.DataFrame({
                'pest_file': [pair[0] for pair in input_output_file_pairs
                              if pair[0].strip('"').lower().endswith('.tpl')],
                'model_file': [pair[1] for pair in input_output_file_pairs
                               if pair[0].strip('"').lower().endswith('.tpl')]
            })

            # Set the output files (INS files)
            pst.model_output_data = pd.DataFrame({
                'pest_file': [pair[0] for pair in input_output_file_pairs
                              if pair[0].strip('"').lower().endswith('.ins')],
                'model_file': [pair[1] for pair in input_output_file_pairs
                               if pair[0].strip('"').lower().endswith('.ins')]
            })

            # Set NTPLFLE and NINSFLE
            pst.control_data.ntplfle = tpl_count
            pst.control_data.ninsfle = ins_count
        # ~~~~~~~~/ Handle input and output files

        # Set NUMCOM, JACFILE, and MESSFILE
        pst.control_data.numcom = 1
        pst.control_data.jacfile = 0
        pst.control_data.messfile = 0

        # Set the PEST run mode
        pst.pestmode = "estimation"

        # ~~~~~~~~ Add LSQR section as a custom attribute
        pst.lsqr_data = {
            "lsqrmode": 1,
            "lsqr_atol": 1e-4,
            "lsqr_btol": 1e-4,
            "lsqr_conlim": 28.0,
            "lsqr_itnlim": 28,
            "lsqrwrite": 0
        }

        # Store the original write method
        original_write = pst.write

        # Define a new write method that replaces the SVD section with LSQR
        def custom_write(self, filename):
            # First, write to a temporary file
            with tempfile.NamedTemporaryFile(mode='w+', delete=False) as temp_file:
                original_write(temp_file.name)
                temp_filename = temp_file.name

            # Read the content of the temporary file
            with open(temp_filename, 'r') as f:
                content = f.read()

            # Replace the SVD section with LSQR
            lsqr_section = (
                f"* lsqr\n"
                f" {self.lsqr_data['lsqrmode']}\n"
                f" {self.lsqr_data['lsqr_atol']} {self.lsqr_data['lsqr_btol']} "
                f"{self.lsqr_data['lsqr_conlim']} {self.lsqr_data['lsqr_itnlim']}\n"
                f" {self.lsqr_data['lsqrwrite']}\n"
            )
            content = re.sub(r'\* singular value decomposition.*?(?=\*|$)',
                             lsqr_section, content, flags=re.DOTALL)

            # Write the modified content to the final file
            with open(filename, 'w') as f:
                f.write(content)

            # Remove the temporary file
            os.unlink(temp_filename)

        # Replace the write method
        pst.write = custom_write.__get__(pst)
        # ~~~~~~~~/ Add LSQR section as a custom attribute

        # Set additional control data parameters
        pst.control_data.rlambda1 = 10.0
        pst.control_data.numlam = 10
        pst.control_data.icov = 1
        pst.control_data.icor = 1
        pst.control_data.ieig = 1

        # Add the command used to run the model executable
        pst.model_command = [model_comand_line]

        # Set the number of iterations
        pst.control_data.noptmax = noptmax

        # Validate output_path
        output_path = validate_output_path(output_path)

        # Build the full path for the output file
        pst_file_path = os.path.join(output_path, pst_filename)

        # Write the PST file
        pst.write(pst_file_path)
        print(f"PST file successfully created: {pst_file_path}")

    except ValueError as ve:
        print(f"ValueError: {ve}")
    except FileNotFoundError as fe:
        print(f"FileNotFoundError: {fe}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
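
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The parameter names, values, bounds,
# observations, command line, and file names below are hypothetical and assume
# the listed TPL/INS and model files already exist on disk; they are not part
# of dpest itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cultivar = {
        'parameters':         {'P1V': 5.0,  'P1D': 75.0},
        'minima_parameters':  {'P1V': 0.0,  'P1D': 0.0},
        'maxima_parameters':  {'P1V': 60.0, 'P1D': 200.0},
        'parameters_grouped': {'cultivar': 'P1V, P1D'},
    }
    observations = pd.DataFrame({
        'variable_name':  ['ADAT', 'MDAT'],
        'value_measured': [152, 200],
        'group':          ['phenology', 'phenology'],
    })
    pst(
        cultivar_parameters=cultivar,
        dataframe_observations=observations,
        model_comand_line='run_dssat.bat',  # hypothetical model command
        input_output_file_pairs=[
            ('WHCER048_CUL.tpl', 'WHCER048.CUL'),  # hypothetical TPL/model pair
            ('PlantGro.ins', 'PlantGro.OUT'),      # hypothetical INS/output pair
        ],
    )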