slurmify
```python
import inspect
import logging
from functools import wraps
from typing import Any, Callable

import submitit

logger = logging.getLogger(__name__)


def slurm_entrypoint(**slurm_kwargs):
    """
    Args:
        **slurm_kwargs: Additional keyword arguments to configure the SLURM job submission.
            slurm_array_parallelism (int, optional): Enables job array mode if provided.
            folder (str, optional): Submitit log directory (default: 'slurm_logs' or 'local_logs').

    Returns:
        Callable: A decorator that wraps the target function for SLURM job submission.

    Raises:
        ValueError: If `slurm_array_parallelism` is used and not all inputs (except 'use_slurm')
            are lists/tuples, or if the input lists do not have the same length.

    Example:
        @slurm_entrypoint(slurm_array_parallelism=4, folder='my_logs')
        def my_function(x, y, use_slurm=False):
            # Function implementation
            pass
    """

    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapper(*args, **kwargs) -> Any:
            # Bind all arguments to their parameter names, filling in defaults,
            # so they can be inspected uniformly below.
            sig = inspect.signature(fn)
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()

            use_slurm = bound_args.arguments.get("use_slurm", False)
            slurm_array_parallelism = slurm_kwargs.get("slurm_array_parallelism")
            is_array = slurm_array_parallelism is not None
            # `is_slurm_available()` is a helper defined elsewhere in this module.
            is_remote = use_slurm and is_slurm_available()

            executor_class = submitit.AutoExecutor if is_remote else submitit.LocalExecutor
            executor_label = "SLURM" if is_remote else "local"
            folder = slurm_kwargs.get("folder", f"{executor_label.lower()}_logs")

            logger.info(f"[slurmify] Using {executor_label}Executor. Logs in '{folder}'")

            executor = executor_class(folder=folder)
            executor.update_parameters(**slurm_kwargs)

            if is_array:
                arg_names = [k for k in bound_args.arguments if k != "use_slurm"]
                arg_lists = [bound_args.arguments[k] for k in arg_names]

                if not all(isinstance(arg, (list, tuple)) for arg in arg_lists):
                    raise ValueError(
                        "[slurmify] All inputs (except 'use_slurm') must be lists/tuples "
                        "when slurm_array_parallelism is used."
                    )

                if not all(len(arg_lists[0]) == len(arg) for arg in arg_lists):
                    raise ValueError("[slurmify] All input lists must have the same length.")

                jobs = executor.map_array(fn, *arg_lists)
                logger.info(f"[slurmify] Submitted job array with job ids: {[job.job_id for job in jobs]}")
                return jobs

            else:
                job = executor.submit(fn, *args, **kwargs)
                logger.info(f"[slurmify] Submitted job with id {job.job_id}")
                return job

        return wrapper
    return decorator
```
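Before submitting anything, the wrapper normalizes positional and keyword arguments with `inspect.signature` so that every parameter (including `use_slurm` and its default) can be looked up by name. A minimal standalone sketch of that binding step, using the hypothetical `my_function` from the docstring example:

```python
import inspect

def my_function(x, y, use_slurm=False):
    pass

# Bind concrete call arguments to the function's parameter names.
sig = inspect.signature(my_function)
bound = sig.bind([1, 2], [3, 4], use_slurm=True)
bound.apply_defaults()

# Every parameter is now addressable by name, defaults included.
use_slurm = bound.arguments.get("use_slurm", False)
arg_lists = [v for k, v in bound.arguments.items() if k != "use_slurm"]
```

This is why array mode can validate "all inputs except `use_slurm` are lists of equal length" without knowing whether the caller passed them positionally or by keyword.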
```python
from typing import Any, Iterator, Tuple


def batch(batch_size: int, *args: list[Any]) -> Iterator[Tuple[list[Any], ...]]:
    """
    Splits multiple lists into batches of a specified size.

    Args:
        batch_size (int): The size of each batch.
        *args (list[Any]): One or more lists to be batched. All lists must be of the same length.

    Yields:
        Iterator[Tuple[list[Any], ...]]: An iterator over tuples, where each tuple contains
            slices of the input lists.

    Raises:
        ValueError: If the input lists are not of the same length.
    """
    if not args:
        return

    length = len(args[0])
    if not all(len(arg) == length for arg in args):
        raise ValueError("All input lists must be the same length.")

    for i in range(0, length, batch_size):
        yield tuple(arg[i:i + batch_size] for arg in args)
```
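A quick usage sketch of `batch`: two equal-length lists are sliced in lockstep, and the final batch is simply shorter when the length is not a multiple of `batch_size`. The helper is repeated here (without type hints) so the snippet runs standalone:

```python
# Standalone copy of the batch helper defined above.
def batch(batch_size, *args):
    if not args:
        return
    length = len(args[0])
    if not all(len(arg) == length for arg in args):
        raise ValueError("All input lists must be the same length.")
    for i in range(0, length, batch_size):
        yield tuple(arg[i:i + batch_size] for arg in args)

xs = [1, 2, 3, 4, 5]
ys = ["a", "b", "c", "d", "e"]

# Three batches: ([1, 2], ["a", "b"]), ([3, 4], ["c", "d"]), ([5], ["e"]).
batches = list(batch(2, xs, ys))
```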