cpg_flow.workflow

Provides a `Workflow` class and a `@stage` decorator that allow workflows to be defined
in a declarative fashion.

A `Stage` object is responsible for creating Hail Batch jobs and declaring the outputs
(files or Metamist analysis objects) that it is expected to produce. Each stage
acts on a `Target`, which can be one of the following:

* SequencingGroup - an individual sequencing group (e.g. the CRAM of a single sample)
* Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
* Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
* MultiCohort - a union of all SGs in this analysis, across Metamist CustomCohorts

A `Workflow` object plugs stages together by resolving the dependencies between the different levels.
Stages are defined in this package and chained into workflows by their inter-stage dependencies.
Workflow names are defined in main.py, which provides a way to choose a workflow via a CLI argument.
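The declarative pattern looks roughly like the sketch below. It is illustrative only: the method names `expected_outputs` and `queue_jobs`, and the import location of `SequencingGroupStage`, are assumptions about the `cpg_flow.stage` API; this module itself only references `@stage`, `SequencingGroupStage`, `Stage.queue_jobs()` and `required_stages`.

```python
# Hypothetical stage definition; see cpg_flow.stage for the real base-class API.
from cpg_flow.stage import SequencingGroupStage, stage


@stage
class Align(SequencingGroupStage):
    def expected_outputs(self, sequencing_group):
        # Declare the file(s) this stage is expected to produce for one
        # sequencing group, e.g. a CRAM path under the dataset prefix.
        ...

    def queue_jobs(self, sequencing_group, inputs):
        # Create Hail Batch jobs here; the Workflow resolves inter-stage
        # dependencies from required_stages and the declared outputs.
        ...
```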
Module-level imports and the singleton accessors `get_workflow()` / `run_workflow()`; the remaining module members are documented individually below.

```python
import functools
from collections import defaultdict
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import networkx as nx

from cpg_flow.inputs import get_multicohort
from cpg_flow.status import MetamistStatusReporter
from cpg_flow.targets import Cohort, MultiCohort
from cpg_flow.utils import get_logger, slugify, timestamp
from cpg_utils import Path
from cpg_utils.config import get_config
from cpg_utils.hail_batch import get_batch, reset_batch

LOGGER = get_logger(__name__)

if TYPE_CHECKING:
    from cpg_flow.stage import Stage, StageDecorator, StageOutput


_workflow: Optional['Workflow'] = None


def get_workflow(dry_run: bool = False) -> 'Workflow':
    global _workflow
    if _workflow is None:
        _workflow = Workflow(dry_run=dry_run)
    return _workflow


def run_workflow(
    stages: list['StageDecorator'] | None = None,
    wait: bool | None = False,
    dry_run: bool = False,
) -> 'Workflow':
    wfl = get_workflow(dry_run=dry_run)
    wfl.run(stages=stages, wait=wait)
    return wfl
```
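A workflow is typically launched by passing the terminal stage classes to `run_workflow()`. A minimal sketch, reusing the hypothetical `Align` stage from above:

```python
from cpg_flow.workflow import get_workflow, run_workflow

# `Align` is the hypothetical stage sketched earlier.
wfl = run_workflow(stages=[Align], wait=False)

# Equivalent manual form:
# wfl = get_workflow(dry_run=False)
# wfl.run(stages=[Align], wait=False)
```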
```python
def path_walk(expected, collected: set | None = None) -> set[Path]:
    """
    Recursive walk of expected_out: if the object is iterable, walk it.
    This gets around the issue with nested lists and dicts,
    mainly around the use of Array outputs from Cromwell.

    Args:
        expected (Any): any type of object containing Paths
        collected (set): all collected paths so far

    Returns:
        a set of all collected Path nodes

    Examples:

    >>> path_walk({'a': {'b': {'c': Path('d')}}})
    {Path('d')}
    >>> path_walk({'a': {'b': {'c': [Path('d'), Path('e')]}}})
    {Path('d'), Path('e')}
    >>> path_walk({'a': Path('b'), 'c': {'d': 'e'}, 'f': Path('g')})
    {Path('b'), Path('g')}
    """
    if collected is None:
        collected = set()

    if expected is None:
        return collected
    if isinstance(expected, dict):
        for value in expected.values():
            collected.update(path_walk(value, collected))
    if isinstance(expected, list | set):
        for value in expected:
            collected.update(path_walk(value, collected))
    if isinstance(expected, str):
        return collected
    if isinstance(expected, Path):
        if expected in collected:
            raise ValueError(f'Duplicate path {expected} in expected_out')
        collected.add(expected)
    return collected
```
Recursive walk of `expected_out`: if the object is iterable, walk it. This gets around the issue with nested lists and dicts, mainly around the use of Array outputs from Cromwell.

Args:
    expected (Any): any type of object containing Paths
    collected (set): all collected paths so far

Returns:
    a set of all collected Path nodes

Examples:

>>> path_walk({'a': {'b': {'c': Path('d')}}})
{Path('d')}
>>> path_walk({'a': {'b': {'c': [Path('d'), Path('e')]}}})
{Path('d'), Path('e')}
>>> path_walk({'a': Path('b'), 'c': {'d': 'e'}, 'f': Path('g')})
{Path('b'), Path('g')}
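One behaviour worth noting: `path_walk` refuses duplicate paths across a set of expected outputs. A minimal doctest-style sketch, assuming `Path` here prints like `pathlib.Path`:

```python
>>> p = Path('out.txt')
>>> path_walk({'first_copy': p, 'second_copy': p})
Traceback (most recent call last):
    ...
ValueError: Duplicate path out.txt in expected_out
```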
```python
class WorkflowError(Exception):
    """
    Error raised by workflow and stage implementation.
    """
```
Error raised by workflow and stage implementation.
```python
class Action(Enum):
    """
    Indicates what a stage should do with a specific target.
    """

    QUEUE = 1
    SKIP = 2
    REUSE = 3
```
Indicates what a stage should do with a specific target.
```python
def skip(
    _fun: Optional['StageDecorator'] = None,
    *,
    reason: str | None = None,
    assume_outputs_exist: bool = False,
) -> Union['StageDecorator', Callable[..., 'StageDecorator']]:
    """
    Decorator on top of `@stage` that sets the `self.skipped` field to True.
    By default, expected outputs of a skipped stage will be checked,
    unless `assume_outputs_exist` is True.

    @skip
    @stage
    class MyStage1(SequencingGroupStage):
        ...

    @skip
    @stage(assume_outputs_exist=True)
    class MyStage2(SequencingGroupStage):
        ...
    """

    def decorator_stage(fun) -> 'StageDecorator':
        """Implements decorator."""

        @functools.wraps(fun)
        def wrapper_stage(*args, **kwargs) -> 'Stage':
            """Decorator helper function."""
            s = fun(*args, **kwargs)
            s.skipped = True
            s.assume_outputs_exist = assume_outputs_exist
            return s

        return wrapper_stage

    if _fun is None:
        return decorator_stage
    return decorator_stage(_fun)
```
Decorator on top of `@stage` that sets the `self.skipped` field to True. By default, expected outputs of a skipped stage will be checked, unless `assume_outputs_exist` is True.

    @skip
    @stage
    class MyStage1(SequencingGroupStage):
        ...

    @skip
    @stage(assume_outputs_exist=True)
    class MyStage2(SequencingGroupStage):
        ...
```python
class Workflow:
    """
    Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups.
    Responsible for orchestrating stages.
    """

    # __init__, cohort_prefix, run and set_stages are listed with their own
    # documentation below.

    @property
    def output_version(self) -> str:
        return self._output_version or get_multicohort().alignment_inputs_hash()

    @property
    def analysis_prefix(self) -> Path:
        return self._prefix(category='analysis')

    @property
    def tmp_prefix(self) -> Path:
        return self._prefix(category='tmp')

    @property
    def web_prefix(self) -> Path:
        return self._prefix(category='web')

    @property
    def prefix(self) -> Path:
        return self._prefix()

    def _prefix(self, category: str | None = None) -> Path:
        """
        Prepare a unique path for the workflow with this name and this input data.
        """
        return get_multicohort().analysis_dataset.prefix(category=category) / self.name / self.output_version

    @staticmethod
    def _process_first_last_stages(
        stages: list['Stage'],
        graph: nx.DiGraph,
        first_stages: list[str],
        last_stages: list[str],
    ):
        """
        Applies the first_stages and last_stages config options: skips all stages
        before first_stages and all stages after last_stages (i.e. descendants and
        ancestors, respectively, on the stage DAG, whose edges point from a stage
        to its dependencies).
        """
        stages_d = {s.name: s for s in stages}
        stage_names = list(stg.name for stg in stages)
        lower_names = {s.lower() for s in stage_names}

        for param, _stage_list in [
            ('first_stages', first_stages),
            ('last_stages', last_stages),
        ]:
            for _s_name in _stage_list:
                if _s_name.lower() not in lower_names:
                    raise WorkflowError(
                        f'Value in workflow/{param} "{_s_name}" must be a stage name '
                        f"or a subset of stages from the available list: "
                        f'{", ".join(stage_names)}',
                    )

        if not (last_stages or first_stages):
            return

        # E.g. if our last_stages is CramQc, MtToEs would still run because it's in
        # a different branch. So we want to collect all stages after first_stages
        # and before last_stages in their respective branches, and mark as skipped
        # everything in other branches.
        first_stages_keeps: list[str] = first_stages[:]
        last_stages_keeps: list[str] = last_stages[:]

        for fs in first_stages:
            for descendant in nx.descendants(graph, fs):
                if not stages_d[descendant].skipped:
                    LOGGER.info(
                        f'Skipping stage {descendant} (precedes {fs} listed in first_stages)',
                    )
                    stages_d[descendant].skipped = True
                    for grand_descendant in nx.descendants(graph, descendant):
                        if not stages_d[grand_descendant].assume_outputs_exist:
                            LOGGER.info(
                                f'Not checking expected outputs of not immediately '
                                f'required stage {grand_descendant} (< {descendant} < {fs})',
                            )
                            stages_d[grand_descendant].assume_outputs_exist = True

            for ancestor in nx.ancestors(graph, fs):
                first_stages_keeps.append(ancestor)

        for ls in last_stages:
            # ancestors of this last_stage
            ancestors = nx.ancestors(graph, ls)
            if any(anc in last_stages for anc in ancestors):
                # a downstream stage is also in last_stages, so this is not yet
                # a "real" last stage that we want to run
                continue
            for ancestor in ancestors:
                if stages_d[ancestor].skipped:
                    continue  # already skipped
                LOGGER.info(f'Skipping stage {ancestor} (after last {ls})')
                stages_d[ancestor].skipped = True
                stages_d[ancestor].assume_outputs_exist = True

            for ancestor in nx.descendants(graph, ls):
                last_stages_keeps.append(ancestor)

        for _stage in stages:
            if _stage.name not in last_stages_keeps + first_stages_keeps:
                _stage.skipped = True
                _stage.assume_outputs_exist = True

    @staticmethod
    def _process_only_stages(
        stages: list['Stage'],
        graph: nx.DiGraph,
        only_stages: list[str],
    ):
        if not only_stages:
            return

        stages_d = {s.name: s for s in stages}
        stage_names = list(stg.name for stg in stages)
        lower_names = {s.lower() for s in stage_names}

        for s_name in only_stages:
            if s_name.lower() not in lower_names:
                raise WorkflowError(
                    f'Value in workflow/only_stages "{s_name}" must be a stage '
                    f"name or a subset of stages from the available list: "
                    f'{", ".join(stage_names)}',
                )

        # We want to run only the stages appearing in only_stages, and check outputs of
        # immediate predecessor stages, but skip everything else.
        required_stages: set[str] = set()
        for os in only_stages:
            rs = nx.descendants_at_distance(graph, os, 1)
            required_stages |= set(rs)

        for stage in stages:
            # Skip stages not in only_stages, and assume their outputs exist...
            if stage.name not in only_stages:
                stage.skipped = True
                stage.assume_outputs_exist = True

        # ...unless the stage is directly required by any stage in only_stages
        for stage_name in required_stages:
            stages_d[stage_name].assume_outputs_exist = False

    @staticmethod
    def _process_stage_errors(
        output_by_target: dict[str, Union['StageOutput', None]],
    ) -> list[str]:
        targets_by_error = defaultdict(list)
        for target, output in output_by_target.items():
            if output and output.error_msg:
                targets_by_error[output.error_msg].append(target)
        return [f'{error}: {", ".join(target_ids)}' for error, target_ids in targets_by_error.items()]
```
Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups. Responsible for orchestrating stages.
```python
    def __init__(
        self,
        stages: list['StageDecorator'] | None = None,
        dry_run: bool | None = None,
    ):
        if _workflow is not None:
            raise ValueError(
                'Workflow already initialised. Use get_workflow() to get the instance',
            )

        self.dry_run = dry_run or get_config(True)['workflow'].get('dry_run')

        # TODO: should the ['dataset'] be a get? should we rename it to analysis dataset?
        analysis_dataset = get_config(True)['workflow']['dataset']
        name = get_config()['workflow'].get('name', analysis_dataset)
        description = get_config()['workflow'].get('description', name)
        self.name = slugify(name)

        self._output_version: str | None = None
        if output_version := get_config()['workflow'].get('output_version'):
            self._output_version = slugify(output_version)

        self.run_timestamp: str = get_config()['workflow'].get('run_timestamp') or timestamp()

        # Description
        if self._output_version:
            description += f': output_version={self._output_version}'
        description += f': run_timestamp={self.run_timestamp}'
        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
            description += f' [{sequencing_type}]'
        if not self.dry_run:
            if ds_set := set(d.name for d in get_multicohort().get_datasets()):
                description += ' ' + ', '.join(sorted(ds_set))
        reset_batch()
        get_batch().name = description

        self.status_reporter = None
        if get_config()['workflow'].get('status_reporter') == 'metamist':
            self.status_reporter = MetamistStatusReporter()
        self._stages: list[StageDecorator] | None = stages
        self.queued_stages: list[Stage] = []
```
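For reference, the `[workflow]` keys read above (and by `set_stages()` below) come from the cpg-utils TOML config. The sketch below lists those keys with illustrative values only; defaults are as implemented in `__init__`:

```toml
# Illustrative values; the keys are the ones read via get_config()['workflow'].
[workflow]
dataset = 'my-dataset'            # required analysis dataset
name = 'seqr_loader'              # optional, defaults to the dataset name
description = 'Seqr loader run'   # optional, defaults to the name
sequencing_type = 'genome'        # optional, appended to the Batch description
output_version = '1.0'            # optional, otherwise an alignment-inputs hash is used
status_reporter = 'metamist'      # optional, enables MetamistStatusReporter
dry_run = false

# Stage selection (see set_stages); only_stages is incompatible with the others.
skip_stages = []
only_stages = []
first_stages = []
last_stages = []
```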
```python
    def cohort_prefix(self, cohort: Cohort, category: str | None = None) -> Path:
        """
        Takes a cohort and a category as arguments and returns a storage path of the form
        PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID,
        e.g. "gs://cpg-project-main/seqr_loader/COH123" or "gs://cpg-project-main-analysis/seqr_loader/COH123".

        Args:
            cohort (Cohort): we pull the analysis dataset and name from this Cohort
            category (str | None): sub-bucket for this project

        Returns:
            Path
        """
        return cohort.analysis_dataset.prefix(category=category) / self.name / cohort.name
```
Takes a cohort and a category as arguments and returns a storage path of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID, e.g. "gs://cpg-project-main/seqr_loader/COH123" or "gs://cpg-project-main-analysis/seqr_loader/COH123".

Args:
    cohort (Cohort): we pull the analysis dataset and name from this Cohort
    category (str | None): sub-bucket for this project

Returns:
    Path
```python
    def run(
        self,
        stages: list['StageDecorator'] | None = None,
        wait: bool | None = False,
    ):
        """
        Resolve stages, add and submit Hail Batch jobs.
        Required stages that were not requested explicitly are resolved
        implicitly and executed as well.
        """
        _stages = stages or self._stages
        if not _stages:
            raise WorkflowError('No stages added')
        self.set_stages(_stages)

        if not self.dry_run:
            get_batch().run(wait=wait)
```
Resolve stages, add and submit Hail Batch jobs. Required stages that were not requested explicitly are resolved implicitly and executed as well.
```python
    def set_stages(
        self,
        requested_stages: list['StageDecorator'],
    ):
        """
        Iterate over stages and call their queue_for_multicohort() methods;
        through that, create all Hail Batch jobs through Stage.queue_jobs().
        """
        # TOML options to configure stages:
        skip_stages = get_config()['workflow'].get('skip_stages', [])
        only_stages = get_config()['workflow'].get('only_stages', [])
        first_stages = get_config()['workflow'].get('first_stages', [])
        last_stages = get_config()['workflow'].get('last_stages', [])

        # Only allow one of only_stages or first_stages/last_stages as they seem
        # to be mutually exclusive.
        if only_stages and (first_stages or last_stages or skip_stages):
            raise WorkflowError(
                "Workflow config parameter 'only_stages' is incompatible with "
                + "'first_stages', 'last_stages' and/or 'skip_stages'",
            )

        LOGGER.info(
            f'End stages for the workflow "{self.name}": {[cls.__name__ for cls in requested_stages]}',
        )
        LOGGER.info('Stages additional configuration:')
        LOGGER.info(f' workflow/skip_stages: {skip_stages}')
        LOGGER.info(f' workflow/only_stages: {only_stages}')
        LOGGER.info(f' workflow/first_stages: {first_stages}')
        LOGGER.info(f' workflow/last_stages: {last_stages}')

        # Round 1: initialising stage objects.
        _stages_d: dict[str, Stage] = {}
        for cls in requested_stages:
            if cls.__name__ in _stages_d:
                continue
            _stages_d[cls.__name__] = cls()

        # Round 2: depth search to find implicit stages.
        depth = 0
        while True:  # might require a few iterations to resolve dependencies recursively
            depth += 1
            newly_implicitly_added_d = dict()
            for stg in _stages_d.values():
                if stg.name in skip_stages:
                    stg.skipped = True
                    continue  # not searching deeper

                if only_stages and stg.name not in only_stages:
                    stg.skipped = True

                # Iterate dependencies:
                for reqcls in stg.required_stages_classes:
                    if reqcls.__name__ in _stages_d:  # already added
                        continue
                    # Initialising and adding as implicit.
                    reqstg = reqcls()
                    newly_implicitly_added_d[reqstg.name] = reqstg

            if newly_implicitly_added_d:
                LOGGER.info(
                    f'Additional implicit stages: {list(newly_implicitly_added_d.keys())}',
                )
                _stages_d |= newly_implicitly_added_d
            else:
                # No new implicit stages added, so we can stop the depth-search here
                break

        # Round 3: set the "stage.required_stages" field on each stage.
        for stg in _stages_d.values():
            stg.required_stages = [
                _stages_d[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in _stages_d
            ]

        # Round 4: determining order of execution.
        dag_node2nodes = dict()  # building a DAG
        for stg in _stages_d.values():
            dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages)
        dag = nx.DiGraph(dag_node2nodes)
        try:
            stage_names = list(reversed(list(nx.topological_sort(dag))))
        except nx.NetworkXUnfeasible:
            LOGGER.error('Circular dependencies found between stages')
            raise

        LOGGER.info(f'Stages in order of execution:\n{stage_names}')
        stages = [_stages_d[name] for name in stage_names]

        # Round 5: applying workflow options first_stages and last_stages.
        if first_stages or last_stages:
            LOGGER.info('Applying workflow/first_stages and workflow/last_stages')
            self._process_first_last_stages(stages, dag, first_stages, last_stages)
        elif only_stages:
            LOGGER.info('Applying workflow/only_stages')
            self._process_only_stages(stages, dag, only_stages)

        if not (final_set_of_stages := [s.name for s in stages if not s.skipped]):
            raise WorkflowError('No stages to run')

        LOGGER.info(
            f'Final list of stages after applying stage configuration options:\n{final_set_of_stages}',
        )

        required_skipped_stages = [s for s in stages if s.skipped]
        if required_skipped_stages:
            LOGGER.info(
                f'Skipped stages: {", ".join(s.name for s in required_skipped_stages)}',
            )

        # Round 6: actually adding jobs from the stages.
        if not self.dry_run:
            inputs = get_multicohort()  # Would communicate with metamist.
            for i, stg in enumerate(stages):
                LOGGER.info('*' * 60)
                LOGGER.info(f'Stage #{i + 1}: {stg}')
                # pipeline setup is now done in MultiCohort only
                # the legacy version (input_datasets) is still supported
                # that will create a MultiCohort with a single Cohort
                if isinstance(inputs, MultiCohort):
                    stg.output_by_target = stg.queue_for_multicohort(inputs)
                else:
                    raise WorkflowError(f'Unsupported input type: {inputs}')
                if errors := self._process_stage_errors(stg.output_by_target):
                    raise WorkflowError(
                        f'Stage {stg} failed to queue jobs with errors: ' + '\n'.join(errors),
                    )

                LOGGER.info('')

        else:
            self.queued_stages = [stg for stg in _stages_d.values() if not stg.skipped]
            LOGGER.info(f'Queued stages: {self.queued_stages}')
```
Iterate over stages and call their queue_for_multicohort() methods; through that, create all Hail Batch jobs through Stage.queue_jobs().
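To make the graph conventions in Rounds 4-5 concrete, here is a small self-contained sketch: edges point from a stage to its dependencies, execution order is the reversed topological sort, and `first_stages`/`last_stages` pruning works on descendants/ancestors of that graph. The stage names are made up for illustration.

```python
import networkx as nx

# Toy dependency mapping in the same shape as dag_node2nodes:
# each stage maps to the set of stages it requires.
dag = nx.DiGraph({
    'Align': set(),
    'CramQc': {'Align'},
    'JointCalling': {'Align'},
    'MtToEs': {'JointCalling'},
})

# Round 4: execution order is the reversed topological sort.
print(list(reversed(list(nx.topological_sort(dag)))))
# e.g. ['Align', 'CramQc', 'JointCalling', 'MtToEs'] (ties may be ordered differently)

# Round 5: with last_stages = ['JointCalling'], its ancestors in this graph
# (stages that depend on it, i.e. run after it) get skipped:
print(nx.ancestors(dag, 'JointCalling'))    # {'MtToEs'}
# With first_stages = ['JointCalling'], its descendants (its dependencies,
# i.e. stages that run before it) get skipped:
print(nx.descendants(dag, 'JointCalling'))  # {'Align'}
```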