slurmify

# Public API of the slurmify package.
from .core import slurm_entrypoint
from .utils import batch

__all__ = ["slurm_entrypoint", "batch"]
def _call_with(fn: Callable, args: tuple, kwargs: dict) -> Any:
    """Return ``fn(*args, **kwargs)``; the per-element callable submitted for job arrays."""
    return fn(*args, **kwargs)


def _array_job_inputs(sig: inspect.Signature, arguments: dict) -> tuple:
    """Validate array-mode inputs and split them into per-job positional/keyword arguments.

    Args:
        sig: Signature of the decorated function.
        arguments: Bound (defaults applied) arguments of the current call.

    Returns:
        tuple: ``(job_args, job_kwargs)`` — two equally long lists, element ``i`` holding
        the positional tuple and keyword dict for the ``i``-th array job.

    Raises:
        ValueError: If there are no inputs besides 'use_slurm', if any such input is not
            a list/tuple, or if the inputs do not all have the same length.
    """
    names = [name for name in arguments if name != "use_slurm"]
    columns = [arguments[name] for name in names]

    if not columns:
        raise ValueError("[slurmify] slurm_array_parallelism requires at least one list/tuple input besides 'use_slurm'.")
    if not all(isinstance(col, (list, tuple)) for col in columns):
        raise ValueError("[slurmify] All inputs (except 'use_slurm') must be lists/tuples when slurm_array_parallelism is used.")
    if not all(len(col) == len(columns[0]) for col in columns):
        raise ValueError("[slurmify] All input lists must have the same length.")

    job_args, job_kwargs = [], []
    for values in zip(*columns):
        # Rebind by parameter name and let BoundArguments decide what goes positionally
        # and what by keyword, so a 'use_slurm' parameter in any position (or
        # keyword-only parameters) does not shift the remaining inputs.
        job_bound = sig.bind_partial()
        job_bound.arguments.update(zip(names, values))
        job_args.append(job_bound.args)
        job_kwargs.append(job_bound.kwargs)
    return job_args, job_kwargs


def slurm_entrypoint(**slurm_kwargs):
    """
    Build a decorator that submits calls to the decorated function through submitit.

    Calling the decorated function does not run it directly. It submits a job and
    returns the submitit job, or a list of jobs in array mode. A ``submitit.AutoExecutor``
    is used when the call's ``use_slurm`` argument is truthy and ``is_slurm_available()``
    returns True. Otherwise a ``submitit.LocalExecutor`` is used.

    Args:
        **slurm_kwargs: Additional keyword arguments to configure the SLURM job submission.
            slurm_array_parallelism (int, optional): Enables job array mode if provided.
            folder (str, optional): Submitit log directory (default: 'slurm_logs' or
                'local_logs'). It is passed only to the executor constructor; every
                other key is forwarded to ``executor.update_parameters``.

    Returns:
        Callable: A decorator that wraps the target function for SLURM job submission.

    Raises:
        ValueError: In array mode, raised before any executor is created if:
            - the function has no inputs besides 'use_slurm';
            - not all inputs (except 'use_slurm') are lists/tuples; or
            - the input lists do not have the same length.

    Example:
        @slurm_entrypoint(slurm_array_parallelism=4, folder='my_logs')
        def my_function(x, y, use_slurm=False):
            # Function implementation
            pass
    """
    # `folder` configures the executor itself, not the job; forwarding it to
    # update_parameters would make submitit reject it as an unknown argument.
    job_parameters = {k: v for k, v in slurm_kwargs.items() if k != "folder"}
    is_array = slurm_kwargs.get("slurm_array_parallelism") is not None

    def decorator(fn: Callable) -> Callable:
        # The signature never changes, so inspect it once instead of on every call.
        sig = inspect.signature(fn)

        @wraps(fn)
        def wrapper(*args, **kwargs) -> Any:
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()

            # Validate array inputs first so bad input fails before an executor is built.
            if is_array:
                job_args, job_kwargs = _array_job_inputs(sig, bound_args.arguments)

            use_slurm = bound_args.arguments.get("use_slurm", False)
            is_remote = use_slurm and is_slurm_available()

            executor_class = submitit.AutoExecutor if is_remote else submitit.LocalExecutor
            executor_label = "SLURM" if is_remote else "local"
            folder = slurm_kwargs.get("folder", f"{executor_label.lower()}_logs")

            logger.info(f"[slurmify] Using {executor_label}Executor. Logs in '{folder}'")

            executor = executor_class(folder=folder)
            executor.update_parameters(**job_parameters)

            if is_array:
                jobs = executor.map_array(_call_with, [fn] * len(job_args), job_args, job_kwargs)
                logger.info(f"[slurmify] Submitted job array with job ids: {[job.job_id for job in jobs]}")
                return jobs

            job = executor.submit(fn, *args, **kwargs)
            logger.info(f"[slurmify] Submitted job with id {job.job_id}")
            return job

        return wrapper
    return decorator

Args: **slurm_kwargs: Additional keyword arguments to configure the SLURM job submission. slurm_array_parallelism (int, optional): Enables job array mode if provided. folder (str, optional): Submitit log directory (default: 'slurm_logs' or 'local_logs').

Returns: Callable: A decorator that wraps the target function for SLURM job submission.

Raises: ValueError: If slurm_array_parallelism is used and not all inputs (except 'use_slurm') are lists/tuples, or if the input lists do not have the same length.

Example: @slurm_entrypoint(slurm_array_parallelism=4, folder='my_logs') def my_function(x, y, use_slurm=False): # Function implementation pass

def batch(batch_size: int, *args: list[Any]) -> Iterator[Tuple[list[Any], ...]]:
    """
    Splits multiple lists into batches of a specified size.

    Args:
        batch_size (int): The size of each batch; must be a positive integer.
        *args (list[Any]): One or more lists to be batched. All lists must be of the same length.

    Yields:
        Tuple[list[Any], ...]: One tuple per batch, holding the matching slice of each
        input list in argument order. The last batch may be shorter than ``batch_size``.
        Nothing is yielded when no lists are given or the lists are empty.

    Raises:
        ValueError: If ``batch_size`` is less than 1, or if the input lists are not of the
            same length. Because this is a generator, the error surfaces on the first
            iteration rather than at call time.
    """
    if batch_size < 1:
        # A zero step would make range() fail with an unrelated message, and a negative
        # step would silently produce no batches at all.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}.")

    if not args:
        return

    length = len(args[0])
    if not all(len(arg) == length for arg in args):
        raise ValueError("All input lists must be the same length.")

    for start in range(0, length, batch_size):
        yield tuple(arg[start:start + batch_size] for arg in args)

Splits multiple lists into batches of a specified size. Args: batch_size (int): The size of each batch. *args (list[Any]): One or more lists to be batched. All lists must be of the same length. Yields: Iterator[Tuple[list[Any], ...]]: An iterator over tuples, where each tuple contains slices of the input lists. Raises: ValueError: If the input lists are not of the same length.