cpg_flow.workflow
Provides a Workflow class and a @stage decorator that allow defining workflows
in a declarative fashion.

A Stage object is responsible for creating Hail Batch jobs and declaring the outputs
(files or metamist analysis objects) that it is expected to produce. Each stage
acts on a Target, which can be one of the following:
* SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
* Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
* Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
* MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort
A Workflow object plugs stages together by resolving the dependencies between stages at these
different target levels. Stages are defined in this package and are chained into Workflows by their
inter-stage dependencies. Workflow names are defined in main.py, which provides a way to choose a
workflow using a CLI argument.
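To make the declarative pattern concrete, here is a minimal sketch of how an end stage might be declared and run. It is illustrative only: the import location of stage and SequencingGroupStage is assumed, and the stage body is elided because the Stage API is documented in cpg_flow.stage, not in this module; only run_workflow and get_workflow below are defined here. The full module source follows.

# Minimal sketch, not taken from this module. The cpg_flow.stage import path
# for `stage` and `SequencingGroupStage` is an assumption for illustration.
from cpg_flow.stage import SequencingGroupStage, stage  # assumed import path
from cpg_flow.workflow import run_workflow


@stage
class MyStage(SequencingGroupStage):
    # Expected outputs and Hail Batch jobs are declared through the Stage API
    # (see cpg_flow.stage); elided here.
    ...


# MyStage is the requested "end" stage; any stages it requires are discovered
# and added implicitly by Workflow.set_stages(). With dry_run=True, stages are
# resolved and recorded in Workflow.queued_stages without submitting a batch.
run_workflow(stages=[MyStage], dry_run=True)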
"""
Provides a `Workflow` class and a `@stage` decorator that allow defining workflows
in a declarative fashion.

A `Stage` object is responsible for creating Hail Batch jobs and declaring outputs
(files or metamist analysis objects) that are expected to be produced. Each stage
acts on a `Target`, which can be one of the following:

* SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
* Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
* Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
* MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort

A `Workflow` object plugs stages together by resolving dependencies between different levels accordingly. Stages are
defined in this package, and chained into Workflows by their inter-Stages dependencies. Workflow names are defined in
main.py, which provides a way to choose a workflow using a CLI argument.
"""

import functools
from collections import defaultdict
from collections.abc import Callable
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union

import networkx as nx
import plotly.io as pio

from cpg_flow.inputs import get_multicohort
from cpg_flow.show_workflow.graph import GraphPlot
from cpg_flow.status import MetamistStatusReporter
from cpg_flow.targets import Cohort, MultiCohort
from cpg_flow.utils import get_logger, slugify, timestamp, write_to_gcs_bucket
from cpg_utils import Path
from cpg_utils.config import get_config
from cpg_utils.hail_batch import get_batch, reset_batch

LOGGER = get_logger(__name__)
URL_BASENAME = 'https://{access_level}-web.populationgenomics.org.au/{name}/'

if TYPE_CHECKING:
    from cpg_flow.stage import Stage, StageDecorator, StageOutput


def path_walk(expected, collected: set | None = None) -> set[Path]:
    """
    recursive walk of expected_out
    if the object is iterable, walk it
    this gets around the issue with nested lists and dicts
    mainly around the use of Array outputs from Cromwell

    Args:
        expected (Any): any type of object containing Paths
        collected (set): all collected paths so far

    Returns:
        a set of all collected Path nodes

    Examples:

    >>> path_walk({'a': {'b': {'c': Path('d')}}})
    {Path('d')}
    >>> path_walk({'a': {'b': {'c': [Path('d'), Path('e')]}}})
    {Path('d'), Path('e')}
    >>> path_walk({'a': Path('b'), 'c': {'d': 'e'}, 'f': Path('g')})
    {Path('b'), Path('g')}
    """
    if collected is None:
        collected = set()

    if expected is None:
        return collected
    if isinstance(expected, dict):
        for value in expected.values():
            collected.update(path_walk(value, collected))
    if isinstance(expected, list | set):
        for value in expected:
            collected.update(path_walk(value, collected))
    if isinstance(expected, str):
        return collected
    if isinstance(expected, Path):
        if expected in collected:
            raise ValueError(f'Duplicate path {expected} in expected_out')
        collected.add(expected)
    return collected


class WorkflowError(Exception):
    """
    Error raised by workflow and stage implementation.
    """


class Action(Enum):
    """
    Indicates what a stage should do with a specific target.
    """

    QUEUE = 1
    SKIP = 2
    REUSE = 3


# noinspection PyUnusedLocal
def skip(
    _fun: Optional['StageDecorator'] = None,
    *,
    reason: str | None = None,
    assume_outputs_exist: bool = False,
) -> Union['StageDecorator', Callable[..., 'StageDecorator']]:
    """
    Decorator on top of `@stage` that sets the `self.skipped` field to True.
    By default, expected outputs of a skipped stage will be checked,
    unless `assume_outputs_exist` is True.

    @skip
    @stage
    class MyStage1(SequencingGroupStage):
        ...

    @skip
    @stage(assume_outputs_exist=True)
    class MyStage2(SequencingGroupStage):
        ...
    """

    def decorator_stage(fun) -> 'StageDecorator':
        """Implements decorator."""

        @functools.wraps(fun)
        def wrapper_stage(*args, **kwargs) -> 'Stage':
            """Decorator helper function."""
            s = fun(*args, **kwargs)
            s.skipped = True
            s.assume_outputs_exist = assume_outputs_exist
            return s

        return wrapper_stage

    if _fun is None:
        return decorator_stage
    return decorator_stage(_fun)


_workflow: Optional['Workflow'] = None


def get_workflow(dry_run: bool = False) -> 'Workflow':
    global _workflow
    if _workflow is None:
        _workflow = Workflow(dry_run=dry_run)
    return _workflow


def run_workflow(
    stages: list['StageDecorator'] | None = None,
    wait: bool | None = False,
    dry_run: bool = False,
) -> 'Workflow':
    wfl = get_workflow(dry_run=dry_run)
    wfl.run(stages=stages, wait=wait)
    return wfl


class Workflow:
    """
    Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups.
    Responsible for orchestrating stages.
    """

    def __init__(
        self,
        stages: list['StageDecorator'] | None = None,
        dry_run: bool | None = None,
    ):
        if _workflow is not None:
            raise ValueError(
                'Workflow already initialised. Use get_workflow() to get the instance',
            )

        self.dry_run = dry_run or get_config(True)['workflow'].get('dry_run')
        self.show_workflow = get_config()['workflow'].get('show_workflow', False)
        self.access_level = get_config()['workflow'].get('access_level', 'test')

        # TODO: should the ['dataset'] be a get? should we rename it to analysis dataset?
        analysis_dataset = get_config(True)['workflow']['dataset']
        name = get_config()['workflow'].get('name', analysis_dataset)
        description = get_config()['workflow'].get('description', name)
        self.name = slugify(name)

        self._output_version: str | None = None
        if output_version := get_config()['workflow'].get('output_version'):
            self._output_version = slugify(output_version)

        self.run_timestamp: str = get_config()['workflow'].get('run_timestamp') or timestamp()

        # Description
        if self._output_version:
            description += f': output_version={self._output_version}'
        description += f': run_timestamp={self.run_timestamp}'
        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
            description += f' [{sequencing_type}]'
        if not self.dry_run:
            if ds_set := set(d.name for d in get_multicohort().get_datasets()):
                description += ' ' + ', '.join(sorted(ds_set))
        reset_batch()
        get_batch().name = description

        self.status_reporter = None
        if get_config()['workflow'].get('status_reporter') == 'metamist':
            self.status_reporter = MetamistStatusReporter()
        self._stages: list[StageDecorator] | None = stages
        self.queued_stages: list[Stage] = []

    @property
    def output_version(self) -> str:
        return self._output_version or get_multicohort().get_alignment_inputs_hash()

    @property
    def analysis_prefix(self) -> Path:
        return self._prefix(category='analysis')

    @property
    def tmp_prefix(self) -> Path:
        return self._prefix(category='tmp')

    @property
    def web_prefix(self) -> Path:
        return self._prefix(category='web')

    @property
    def prefix(self) -> Path:
        return self._prefix()

    def _prefix(self, category: str | None = None) -> Path:
        """
        Prepare a unique path for the workflow with this name and this input data.
        """
        return get_multicohort().analysis_dataset.prefix(category=category) / self.name / self.output_version

    def cohort_prefix(self, cohort: Cohort, category: str | None = None) -> Path:
        """
        Takes a cohort and a category as arguments.
        Result in the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID
        e.g. "gs://cpg-project-main/seqr_loader/COH123", or "gs://cpg-project-main-analysis/seqr_loader/COH123"

        Args:
            cohort (Cohort): we pull the analysis dataset and id from this Cohort
            category (str | None): sub-bucket for this project

        Returns:
            Path
        """
        return cohort.analysis_dataset.prefix(category=category) / self.name / cohort.id

    def run(
        self,
        stages: list['StageDecorator'] | None = None,
        wait: bool | None = False,
    ):
        """
        Resolve stages, add and submit Hail Batch jobs.
        When `run_all_implicit_stages` is set, all required stages that were not defined
        explicitly would still be executed.
        """
        _stages = stages or self._stages
        if not _stages:
            raise WorkflowError('No stages added')
        self.set_stages(_stages)

        if not self.dry_run:
            get_batch().run(wait=wait)

    @staticmethod
    def _process_first_last_stages(
        stages: list['Stage'],
        graph: nx.DiGraph,
        first_stages: list[str],
        last_stages: list[str],
    ):
        """
        Applying first_stages and last_stages config options. Would skip all stages
        before first_stages, and all stages after last_stages (i.e. descendants and
        ancestors on the stages DAG.)
        """
        stages_d = {s.name: s for s in stages}
        stage_names = list(stg.name for stg in stages)
        lower_names = {s.lower() for s in stage_names}

        for param, _stage_list in [
            ('first_stages', first_stages),
            ('last_stages', last_stages),
        ]:
            for _s_name in _stage_list:
                if _s_name.lower() not in lower_names:
                    raise WorkflowError(
                        f'Value in workflow/{param} "{_s_name}" must be a stage name '
                        f"or a subset of stages from the available list: "
                        f'{", ".join(stage_names)}',
                    )

        if not (last_stages or first_stages):
            return

        # E.g. if our last_stages is CramQc, MtToEs would still run because it's in
        # a different branch. So we want to collect all stages after first_stages
        # and before last_stages in their respective branches, and mark as skipped
        # everything in other branches.
        first_stages_keeps: list[str] = first_stages[:]
        last_stages_keeps: list[str] = last_stages[:]

        for fs in first_stages:
            for descendant in nx.descendants(graph, fs):
                if not stages_d[descendant].skipped:
                    LOGGER.info(
                        f'Skipping stage {descendant} (precedes {fs} listed in first_stages)',
                    )
                    stages_d[descendant].skipped = True
                    for grand_descendant in nx.descendants(graph, descendant):
                        if not stages_d[grand_descendant].assume_outputs_exist:
                            LOGGER.info(
                                f'Not checking expected outputs of not immediately '
                                f'required stage {grand_descendant} (< {descendant} < {fs})',
                            )
                            stages_d[grand_descendant].assume_outputs_exist = True

            for ancestor in nx.ancestors(graph, fs):
                first_stages_keeps.append(ancestor)

        for ls in last_stages:
            # ancestors of this last_stage
            ancestors = nx.ancestors(graph, ls)
            if any(anc in last_stages for anc in ancestors):
                # a downstream stage is also in last_stages, so this is not yet
                # a "real" last stage that we want to run
                continue
            for ancestor in ancestors:
                if stages_d[ancestor].skipped:
                    continue  # already skipped
                LOGGER.info(f'Skipping stage {ancestor} (after last {ls})')
                stages_d[ancestor].skipped = True
                stages_d[ancestor].assume_outputs_exist = True

            for ancestor in nx.descendants(graph, ls):
                last_stages_keeps.append(ancestor)

        for _stage in stages:
            if _stage.name not in last_stages_keeps + first_stages_keeps:
                _stage.skipped = True
                _stage.assume_outputs_exist = True

        for stage in stages:
            if stage.skipped:
                graph.nodes[stage.name]['skipped'] = True

    @staticmethod
    def _process_only_stages(
        stages: list['Stage'],
        graph: nx.DiGraph,
        only_stages: list[str],
    ):
        if not only_stages:
            return

        stages_d = {s.name: s for s in stages}
        stage_names = list(stg.name for stg in stages)
        lower_names = {s.lower() for s in stage_names}

        for s_name in only_stages:
            if s_name.lower() not in lower_names:
                raise WorkflowError(
                    f'Value in workflow/only_stages "{s_name}" must be a stage '
                    f"name or a subset of stages from the available list: "
                    f'{", ".join(stage_names)}',
                )

        # We want to run stages only appearing in only_stages, and check outputs of
        # immediate predecessor stages, but skip everything else.
        required_stages: set[str] = set()
        for os in only_stages:
            rs = nx.descendants_at_distance(graph, os, 1)
            required_stages |= set(rs)

        for stage in stages:
            # Skip stage not in only_stages, and assume outputs exist...
            if stage.name not in only_stages:
                stage.skipped = True
                stage.assume_outputs_exist = True

        # ...unless stage is directly required by any stage in only_stages
        for stage_name in required_stages:
            stages_d[stage_name].assume_outputs_exist = False

    def set_stages(
        self,
        requested_stages: list['StageDecorator'],
    ):
        """
        Iterate over stages and call their queue_for_cohort(cohort) methods;
        through that, creates all Hail Batch jobs through Stage.queue_jobs().
        """
        # TOML options to configure stages:
        skip_stages = get_config()['workflow'].get('skip_stages', [])
        only_stages = get_config()['workflow'].get('only_stages', [])
        first_stages = get_config()['workflow'].get('first_stages', [])
        last_stages = get_config()['workflow'].get('last_stages', [])

        # Only allow one of only_stages or first_stages/last_stages as they seem
        # to be mutually exclusive.
        if only_stages and (first_stages or last_stages or skip_stages):
            raise WorkflowError(
                "Workflow config parameter 'only_stages' is incompatible with "
                + "'first_stages', 'last_stages' and/or 'skip_stages'",
            )

        LOGGER.info(
            f'End stages for the workflow "{self.name}": {[cls.__name__ for cls in requested_stages]}',
        )
        LOGGER.info('Stages additional configuration:')
        LOGGER.info(f' workflow/skip_stages: {skip_stages}')
        LOGGER.info(f' workflow/only_stages: {only_stages}')
        LOGGER.info(f' workflow/first_stages: {first_stages}')
        LOGGER.info(f' workflow/last_stages: {last_stages}')

        # Round 1: initialising stage objects.
        _stages_d: dict[str, Stage] = {}
        for cls in requested_stages:
            if cls.__name__ in _stages_d:
                continue
            _stages_d[cls.__name__] = cls()

        # Round 2: depth search to find implicit stages.
        depth = 0
        while True:  # might require a few iterations to resolve dependencies recursively
            depth += 1
            newly_implicitly_added_d = dict()
            for stg in _stages_d.values():
                if stg.name in skip_stages:
                    stg.skipped = True
                    continue  # not searching deeper

                if only_stages and stg.name not in only_stages:
                    stg.skipped = True

                # Iterate dependencies:
                for reqcls in stg.required_stages_classes:
                    if reqcls.__name__ in _stages_d:  # already added
                        continue
                    # Initialising and adding as implicit.
                    reqstg = reqcls()
                    newly_implicitly_added_d[reqstg.name] = reqstg

            if newly_implicitly_added_d:
                LOGGER.info(
                    f'Additional implicit stages: {list(newly_implicitly_added_d.keys())}',
                )
                _stages_d |= newly_implicitly_added_d
            else:
                # No new implicit stages added, so can stop the depth-search here
                break

        # Round 3: set "stage.required_stages" fields to each stage.
        for stg in _stages_d.values():
            stg.required_stages = [
                _stages_d[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in _stages_d
            ]

        # Round 4: determining order of execution.
        dag_node2nodes = dict()  # building a DAG
        for stg in _stages_d.values():
            dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages)
        dag = nx.DiGraph(dag_node2nodes)

        try:
            stage_names = list(reversed(list(nx.topological_sort(dag))))
        except nx.NetworkXUnfeasible:
            LOGGER.error('Circular dependencies found between stages')
            raise

        LOGGER.info(f'Stages in order of execution:\n{stage_names}')
        stages = [_stages_d[name] for name in stage_names]

        # Set order attribute to stages
        nx.set_node_attributes(dag, {s.name: num for num, s in enumerate(stages)}, name='order')

        # Update dag with the skipped attribute so it can be updated in self._process_first_last_stages
        nx.set_node_attributes(dag, False, name='skipped')

        # Round 5: applying workflow options first_stages and last_stages.
        if first_stages or last_stages:
            LOGGER.info('Applying workflow/first_stages and workflow/last_stages')
            self._process_first_last_stages(stages, dag, first_stages, last_stages)
        elif only_stages:
            LOGGER.info('Applying workflow/only_stages')
            self._process_only_stages(stages, dag, only_stages)

        if not (final_set_of_stages := [s.name for s in stages if not s.skipped]):
            raise WorkflowError('No stages to run')

        LOGGER.info(
            f'Final list of stages after applying stage configuration options:\n{final_set_of_stages}',
        )

        required_skipped_stages = [s for s in stages if s.skipped]
        if required_skipped_stages:
            LOGGER.info(
                f'Skipped stages: {", ".join(s.name for s in required_skipped_stages)}',
            )

        # Round 6: actually adding jobs from the stages.
        if not self.dry_run:
            inputs = get_multicohort()  # Would communicate with metamist.
            for i, stg in enumerate(stages):
                LOGGER.info('*' * 60)
                LOGGER.info(f'Stage #{i + 1}: {stg}')
                # pipeline setup is now done in MultiCohort only
                # the legacy version (input_datasets) is still supported
                # that will create a MultiCohort with a single Cohort
                if isinstance(inputs, MultiCohort):
                    stg.output_by_target = stg.queue_for_multicohort(inputs)
                else:
                    raise WorkflowError(f'Unsupported input type: {inputs}')
                if errors := self._process_stage_errors(stg.output_by_target):
                    raise WorkflowError(
                        f'Stage {stg} failed to queue jobs with errors: ' + '\n'.join(errors),
                    )

                LOGGER.info('')

        else:
            self.queued_stages = [stg for stg in _stages_d.values() if not stg.skipped]
            LOGGER.info(f'Queued stages: {self.queued_stages}')

        # Round 7: show the workflow

        def format_meta(attr: list):
            return {s: s in attr for s in dag.nodes}

        # First add remaining metadata
        nx.set_node_attributes(dag, format_meta(skip_stages), name='skip_stages')
        nx.set_node_attributes(dag, format_meta(only_stages), name='only_stages')
        nx.set_node_attributes(dag, format_meta(first_stages), name='first_stages')
        nx.set_node_attributes(dag, format_meta(last_stages), name='last_stages')

        if self.show_workflow:
            gp = GraphPlot(dag, title='Full Workflow Graph')

            # Remove skipped steps for the simple graph
            all_nodes = list(dag.nodes)
            _ = [dag.remove_node(n) for n in all_nodes if dag.nodes[n]['skipped']]
            gp2 = GraphPlot(dag, title='Sub-Workflow Graph')

            fig = gp + gp2

            # Show the figure
            fig.show()

            # If we have a web bucket path
            try:
                if web_prefix := self.web_prefix:
                    html_path = web_prefix / f'{self.name}_workflow.html'
                    if str(html_path).startswith('gs:/'):
                        html_file = pio.to_html(fig, full_html=True)
                        _, file_path = write_to_gcs_bucket(html_file, html_path)
                        url = URL_BASENAME.format(access_level=self.access_level, name=self.name) + str(file_path)

                        LOGGER.info(f'Link to the graph: {url}')
                    else:
                        pio.write_html(fig, file=str(html_path), auto_open=False)

                        LOGGER.info(f'Workflow graph saved to {html_path}')

            except ConnectionError as e:
                LOGGER.error(f'Failed to save workflow graph: {e}')

    @staticmethod
    def _process_stage_errors(
        output_by_target: dict[str, Union['StageOutput', None]],
    ) -> list[str]:
        targets_by_error = defaultdict(list)
        for target, output in output_by_target.items():
            if output and output.error_msg:
                targets_by_error[output.error_msg].append(target)
        return [f'{error}: {", ".join(target_ids)}' for error, target_ids in targets_by_error.items()]
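A small, self-contained illustration of the graph conventions used above may help: in the DAG built in set_stages (Round 4), an edge points from a stage to each stage it requires, so nx.descendants() walks towards upstream dependencies, nx.ancestors() walks towards downstream dependents, and the topological sort is reversed to obtain the execution order. The stage names below are hypothetical.

# Toy DAG using the same edge direction as Workflow.set_stages: an edge goes
# from a stage to each stage it requires (hypothetical stage names).
import networkx as nx

# Export requires Call, and Call requires Align.
dag = nx.DiGraph({'Export': {'Call'}, 'Call': {'Align'}, 'Align': set()})

# A topological sort lists requirements last, so it is reversed to get the
# order in which jobs are queued.
print(list(reversed(list(nx.topological_sort(dag)))))  # ['Align', 'Call', 'Export']

# With this edge direction, "descendants" are upstream dependencies and
# "ancestors" are downstream dependents - the convention that
# _process_first_last_stages relies on when skipping stages outside the
# first_stages/last_stages window.
print(nx.descendants(dag, 'Call'))  # {'Align'}
print(nx.ancestors(dag, 'Call'))    # {'Export'}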