Source code for yaflux._step

import functools
import inspect
import time
from typing import Any, Callable, Optional, TypeVar, Union

from yaflux._ast import validate_ast
from yaflux._base import Base
from yaflux._metadata import Metadata
from yaflux._results._lock import ResultsLock

T = TypeVar("T")


def _pull_flags(arglist: list[str]) -> tuple[list[str], list[str]]:
    args = []
    flags = []
    for arg in arglist:
        if arg.startswith("_"):
            flags.append(arg)
        else:
            args.append(arg)
    return args, flags


def _normalize_list(value: Optional[Union[list[str], str]]) -> list[str]:
    """Convert string or list input to normalized list."""
    if isinstance(value, str):
        return [value]
    return value or []


def _validate_instance_method(args: tuple) -> tuple[Base, tuple]:
    """Ensure the decorated function is called as instance method."""
    if not args or not isinstance(args[0], Base):
        raise ValueError("Analysis steps must be called as instance methods")
    return args[0], args[1:]


def _check_requirements(
    analysis: Base, requires: list[str], mutates: list[str]
) -> None:
    """Validate that all required results exist."""
    all_requirements = requires + mutates
    missing = [req for req in all_requirements if not hasattr(analysis._results, req)]
    if missing:
        raise ValueError(
            f"Missing required results: {missing}. Run required steps first."
        )


def _check_required_flags(analysis: Base, requires: list[str]) -> None:
    """Validate that all required flags exist."""
    missing = [req for req in requires if not hasattr(analysis._results, req)]
    if missing:
        raise ValueError(f"Missing required flag: {missing}. Run required steps first.")


def _handle_existing_attributes(
    analysis: Base, creates: list[str], force: bool, panic: bool
) -> Optional[dict[str, Any]]:
    """Handle cases where results already exist."""
    for attr in creates:
        if hasattr(analysis._results, attr):
            if force:
                with ResultsLock.allow_mutation():
                    delattr(analysis._results, attr)
            elif panic:
                raise ValueError(f"Attribute {attr} already exists!")
            else:
                print(f"Attribute {attr} already exists - skipping step.")
                return {attr: getattr(analysis._results, attr) for attr in creates}
    return None


def _store_dict_results(
    analysis: Base, creates: list[str], result: dict[str, Any]
) -> None:
    """Store the function results in the analysis object.

    Assumes the keys of the dictionary are the names of the results.
    """
    # Case where the returned dictionary is missing keys from the creates list
    for c in creates:
        if c not in result:
            raise ValueError(f"Missing result key: {c}")

    # Case where the returned dictionary has unexpected keys not in the creates list
    for attr in result:
        if attr not in creates:
            raise AttributeError(f"Unexpected result key: {attr}")

    # Store the results in the analysis object
    for attr, value in result.items():
        setattr(analysis._results, attr, value)


def _store_inferred_tuple_results(
    analysis: Base, creates: list[str], result: tuple
) -> None:
    """Store the function results in the analysis object.

    Assumes the order of the tuple is the same as the order of the creates list.

    Will panic if the length of the tuple doesn't match the length of the creates list.
    """
    if len(result) != len(creates):
        raise ValueError("Tuple result must have the same length as the creates list")

    for attr, value in zip(creates, result):
        setattr(analysis._results, attr, value)


def _store_singular_result(analysis: Base, creates: list[str], result: Any) -> None:
    """Store the function results in the analysis object."""
    if len(creates) != 1:
        raise ValueError("Single result must have exactly one name in the creates list")
    setattr(analysis._results, creates[0], result)


def _store_results(
    analysis: Base,
    creates: list[str],
    result: Any,
) -> None:
    """Store the function results in the analysis object."""

    # If the result is a dictionary, unpack it into the results object
    if isinstance(result, dict):
        try:  # Try to store the dictionary results
            _store_dict_results(analysis, creates, result)
        except (
            ValueError
        ):  # Case where the dictionary is missing keys from the creates list
            _store_singular_result(analysis, creates, result)
        except (
            AttributeError
        ):  # Case where the dictionary has a superset of keys not in the creates list
            raise ValueError(
                "Unambiguous result keys in dictionary (superset of creates list)"
            )

    # If the result is a tuple, unpack it into the results object
    elif isinstance(result, tuple):
        try:
            _store_inferred_tuple_results(analysis, creates, result)

        # Case where the tuple is not the same length as the creates list
        # This means that the user is returning a singular value as a tuple
        except ValueError:
            _store_singular_result(analysis, creates, result)

    # Base base where the result is a single value
    elif result is not None and len(creates) == 1:
        _store_singular_result(analysis, creates, result)


def _set_flags(analysis: Base, flags: list[str]) -> None:
    """Set the flags on the analysis object."""
    for flag in flags:
        setattr(analysis._results, flag, True)


def _store_metadata(
    analysis: Base,
    step_name: str,
    metadata: Metadata,
):
    """Store the step metadata within the analysis object."""
    analysis._results.set_metadata(step_name, metadata)


def _filter_valid_kwargs(func: Callable, kwargs: dict) -> dict:
    """Remove kwargs that aren't in the function signature."""
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}


[docs] def step( creates: list[str] | str | None = None, requires: list[str] | str | None = None, mutates: list[str] | str | None = None, ) -> Callable: """Decorator to register analysis steps and their results. Parameters ---------- creates : str | list[str] | None Names of the results this step creates requires : str | list[str] | None Names of the results this step requires mutates: str | list[str] | None Names of the results this step mutates Attributes ---------- creates : list[str] Names of the results this step creates requires : list[str] Names of the results this step requires mutates: list[str] Names of the results this step mutates creates_flags : list[str] Names of the flags this step creates requires_flags : list[str] Names of the flags this step requires """ creates_list = _normalize_list(creates) requires_list = _normalize_list(requires) mutates_list = _normalize_list(mutates) creates_list, creates_flags = _pull_flags(creates_list) requires_list, requires_flags = _pull_flags(requires_list) def decorator(func: Callable[..., T]) -> Callable[..., T]: # Validate AST before wrapping the function validate_ast(func, requires=requires_list, mutates=mutates_list) @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> T: # Identify the step name step_name = func.__name__ # Extract control flags force = kwargs.pop("force", False) panic_on_existing = kwargs.pop("panic_on_existing", False) # Setup and validation analysis_obj, remaining_args = _validate_instance_method(args) _check_requirements(analysis_obj, requires_list, mutates_list) _check_required_flags(analysis_obj, requires_flags) # Handle existing results existing = _handle_existing_attributes( analysis_obj, creates_list, force, panic_on_existing ) if existing is not None: return existing # type: ignore # Filter valid kwargs valid_kwargs = _filter_valid_kwargs(func, kwargs) # Setup mutable context mutable_keys = set(mutates_list) if mutates_list else None with ResultsLock.allow_mutation(mutable_keys): # Timestamp the start of the step start_time = time.time() # Execute the function result = func(analysis_obj, *remaining_args, **valid_kwargs) # Record the elapsed time elapsed = time.time() - start_time # Build the metadata object step_metadata = Metadata( creates=creates_list, requires=requires_list, timestamp=start_time, elapsed=elapsed, args=[str(arg) for arg in remaining_args], kwargs={k: str(v) for k, v in valid_kwargs.items()}, ) with ResultsLock.allow_mutation(): # Store the results _store_results(analysis_obj, creates_list, result) # Set the flags _set_flags(analysis_obj, creates_flags) # Store the metadata _store_metadata(analysis_obj, step_name, step_metadata) # Mark completion analysis_obj._completed_steps.add(step_name) # Add to ordering if not already present if step_name not in analysis_obj._step_ordering: analysis_obj._step_ordering.append(step_name) return result # Store metadata wrapper.creates = creates_list # type: ignore wrapper.requires = requires_list # type: ignore wrapper.creates_flags = creates_flags # type: ignore wrapper.requires_flags = requires_flags # type: ignore wrapper.mutates = mutates_list # type: ignore return wrapper return decorator