cpg_flow.workflow

Provides a Workflow class and a @stage decorator that allow workflows to be defined in a declarative fashion.

A Stage object is responsible for creating Hail Batch jobs and declaring outputs (files or metamist analysis objects) that are expected to be produced. Each stage acts on a Target, which can be one of the following:

* SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
* Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
* Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
* MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort

A Workflow object plugs stages together by resolving dependencies between the different target levels. Stages are defined in this package, and chained into Workflows by their inter-stage dependencies. Workflow names are defined in main.py, which provides a way to choose a workflow using a CLI argument.
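
As a minimal sketch of how this fits together (assuming the `stage` decorator and its `required_stages` keyword, the `SequencingGroupStage` base class, and the `expected_outputs` / `queue_jobs` / `make_outputs` / `StageInput.as_path` hooks from `cpg_flow.stage`; the stage names and output paths are purely illustrative):

    from cpg_flow.stage import SequencingGroupStage, stage
    from cpg_flow.workflow import run_workflow


    @stage
    class Align(SequencingGroupStage):
        def expected_outputs(self, sequencing_group):
            # Illustrative output: one CRAM per sequencing group.
            return sequencing_group.dataset.prefix() / 'cram' / f'{sequencing_group.id}.cram'

        def queue_jobs(self, sequencing_group, inputs):
            jobs = []  # create Hail Batch jobs here, e.g. via get_batch().new_job(...)
            return self.make_outputs(sequencing_group, data=self.expected_outputs(sequencing_group), jobs=jobs)


    @stage(required_stages=[Align])
    class CramQc(SequencingGroupStage):
        def expected_outputs(self, sequencing_group):
            return sequencing_group.dataset.prefix() / 'qc' / f'{sequencing_group.id}.stats'

        def queue_jobs(self, sequencing_group, inputs):
            cram_path = inputs.as_path(target=sequencing_group, stage=Align)  # assumed StageInput helper
            jobs = []  # jobs that read cram_path and write the stats file
            return self.make_outputs(sequencing_group, data=self.expected_outputs(sequencing_group), jobs=jobs)


    # Only the end stages need to be requested; required stages are added implicitly.
    run_workflow(stages=[CramQc])
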

  1"""
  2Provides a `Workflow` class and a `@stage` decorator that allow workflows to be defined
  3in a declarative fashion.
  4
  5A `Stage` object is responsible for creating Hail Batch jobs and declaring outputs
  6(files or metamist analysis objects) that are expected to be produced. Each stage
  7acts on a `Target`, which can be one of the following:
  8
  9    * SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
 10    * Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
 11    * Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
 12    * MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort
 13
 14A `Workflow` object plugs stages together by resolving dependencies between the different target levels. Stages are
 15defined in this package, and chained into Workflows by their inter-stage dependencies. Workflow names are defined in
 16main.py, which provides a way to choose a workflow using a CLI argument.
 17"""
 18
 19import functools
 20from collections import defaultdict
 21from collections.abc import Callable
 22from enum import Enum
 23from typing import TYPE_CHECKING, Optional, Union
 24
 25import networkx as nx
 26import plotly.io as pio
 27
 28from cpg_flow.inputs import get_multicohort
 29from cpg_flow.show_workflow.graph import GraphPlot
 30from cpg_flow.status import MetamistStatusReporter
 31from cpg_flow.targets import Cohort, MultiCohort
 32from cpg_flow.utils import get_logger, slugify, timestamp, write_to_gcs_bucket
 33from cpg_utils import Path
 34from cpg_utils.config import get_config
 35from cpg_utils.hail_batch import get_batch, reset_batch
 36
 37LOGGER = get_logger(__name__)
 38URL_BASENAME = 'https://{access_level}-web.populationgenomics.org.au/{name}/'
 39
 40if TYPE_CHECKING:
 41    from cpg_flow.stage import Stage, StageDecorator, StageOutput
 42
 43
 44def path_walk(expected, collected: set | None = None) -> set[Path]:
 45    """
 46    Recursive walk of expected_out:
 47    if the object is iterable, walk it.
 48    This gets around the issue with nested lists and dicts,
 49    mainly arising from the use of Array outputs from Cromwell.
 50
 51    Args:
 52        expected (Any): any type of object containing Paths
 53        collected (set): all collected paths so far
 54
 55    Returns:
 56        a set of all collected Path nodes
 57
 58    Examples:
 59
 60    >>> path_walk({'a': {'b': {'c': Path('d')}}})
 61    {Path('d')}
 62    >>> path_walk({'a': {'b': {'c': [Path('d'), Path('e')]}}})
 63    {Path('d'), Path('e')}
 64    >>> path_walk({'a': Path('b'), 'c': {'d': 'e'}, 'f': Path('g')})
 65    {Path('b'), Path('g')}
 66    """
 67    if collected is None:
 68        collected = set()
 69
 70    if expected is None:
 71        return collected
 72    if isinstance(expected, dict):
 73        for value in expected.values():
 74            collected.update(path_walk(value, collected))
 75    if isinstance(expected, list | set):
 76        for value in expected:
 77            collected.update(path_walk(value, collected))
 78    if isinstance(expected, str):
 79        return collected
 80    if isinstance(expected, Path):
 81        if expected in collected:
 82            raise ValueError(f'Duplicate path {expected} in expected_out')
 83        collected.add(expected)
 84    return collected
 85
 86
 87class WorkflowError(Exception):
 88    """
 89    Error raised by workflow and stage implementation.
 90    """
 91
 92
 93class Action(Enum):
 94    """
 95    Indicates what a stage should do with a specific target.
 96    """
 97
 98    QUEUE = 1
 99    SKIP = 2
100    REUSE = 3
101
102
103# noinspection PyUnusedLocal
104def skip(
105    _fun: Optional['StageDecorator'] = None,
106    *,
107    reason: str | None = None,
108    assume_outputs_exist: bool = False,
109) -> Union['StageDecorator', Callable[..., 'StageDecorator']]:
110    """
111    Decorator on top of `@stage` that sets the `self.skipped` field to True.
112    By default, expected outputs of a skipped stage will be checked,
113    unless `assume_outputs_exist` is True.
114
115    @skip
116    @stage
117    class MyStage1(SequencingGroupStage):
118        ...
119
120    @skip
121    @stage(assume_outputs_exist=True)
122    class MyStage2(SequencingGroupStage):
123        ...
124    """
125
126    def decorator_stage(fun) -> 'StageDecorator':
127        """Implements decorator."""
128
129        @functools.wraps(fun)
130        def wrapper_stage(*args, **kwargs) -> 'Stage':
131            """Decorator helper function."""
132            s = fun(*args, **kwargs)
133            s.skipped = True
134            s.assume_outputs_exist = assume_outputs_exist
135            return s
136
137        return wrapper_stage
138
139    if _fun is None:
140        return decorator_stage
141    return decorator_stage(_fun)
142
143
144_workflow: Optional['Workflow'] = None
145
146
147def get_workflow(dry_run: bool = False) -> 'Workflow':
148    global _workflow
149    if _workflow is None:
150        _workflow = Workflow(dry_run=dry_run)
151    return _workflow
152
153
154def run_workflow(
155    stages: list['StageDecorator'] | None = None,
156    wait: bool | None = False,
157    dry_run: bool = False,
158) -> 'Workflow':
159    wfl = get_workflow(dry_run=dry_run)
160    wfl.run(stages=stages, wait=wait)
161    return wfl
162
163
164class Workflow:
165    """
166    Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups.
167    Responsible for orchestrating stages.
168    """
169
170    def __init__(
171        self,
172        stages: list['StageDecorator'] | None = None,
173        dry_run: bool | None = None,
174    ):
175        if _workflow is not None:
176            raise ValueError(
177                'Workflow already initialised. Use get_workflow() to get the instance',
178            )
179
180        self.dry_run = dry_run or get_config(True)['workflow'].get('dry_run')
181        self.show_workflow = get_config()['workflow'].get('show_workflow', False)
182        self.access_level = get_config()['workflow'].get('access_level', 'test')
183
184        # TODO: should the ['dataset'] be a get? should we rename it to analysis dataset?
185        analysis_dataset = get_config(True)['workflow']['dataset']
186        name = get_config()['workflow'].get('name', analysis_dataset)
187        description = get_config()['workflow'].get('description', name)
188        self.name = slugify(name)
189
190        self._output_version: str | None = None
191        if output_version := get_config()['workflow'].get('output_version'):
192            self._output_version = slugify(output_version)
193
194        self.run_timestamp: str = get_config()['workflow'].get('run_timestamp') or timestamp()
195
196        # Description
197        if self._output_version:
198            description += f': output_version={self._output_version}'
199        description += f': run_timestamp={self.run_timestamp}'
200        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
201            description += f' [{sequencing_type}]'
202        if not self.dry_run:
203            if ds_set := set(d.name for d in get_multicohort().get_datasets()):
204                description += ' ' + ', '.join(sorted(ds_set))
205            reset_batch()
206            get_batch().name = description
207
208        self.status_reporter = None
209        if get_config()['workflow'].get('status_reporter') == 'metamist':
210            self.status_reporter = MetamistStatusReporter()
211        self._stages: list[StageDecorator] | None = stages
212        self.queued_stages: list[Stage] = []
213
214    @property
215    def output_version(self) -> str:
216        return self._output_version or get_multicohort().get_alignment_inputs_hash()
217
218    @property
219    def analysis_prefix(self) -> Path:
220        return self._prefix(category='analysis')
221
222    @property
223    def tmp_prefix(self) -> Path:
224        return self._prefix(category='tmp')
225
226    @property
227    def web_prefix(self) -> Path:
228        return self._prefix(category='web')
229
230    @property
231    def prefix(self) -> Path:
232        return self._prefix()
233
234    def _prefix(self, category: str | None = None) -> Path:
235        """
236        Prepare a unique path for the workflow with this name and this input data.
237        """
238        return get_multicohort().analysis_dataset.prefix(category=category) / self.name / self.output_version
239
240    def cohort_prefix(self, cohort: Cohort, category: str | None = None) -> Path:
241        """
242        Takes a cohort and category as arguments, and returns a path of the form
243        PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID,
244        e.g. "gs://cpg-project-main/seqr_loader/COH123", or "gs://cpg-project-main-analysis/seqr_loader/COH123"
245
246        Args:
247            cohort (Cohort): we pull the analysis dataset and id from this Cohort
248            category (str | None): sub-bucket for this project
249
250        Returns:
251            Path
252        """
253        return cohort.analysis_dataset.prefix(category=category) / self.name / cohort.id
254
255    def run(
256        self,
257        stages: list['StageDecorator'] | None = None,
258        wait: bool | None = False,
259    ):
260        """
261        Resolve stages, add and submit Hail Batch jobs.
262        When `run_all_implicit_stages` is set, all required stages that were not defined
263        explicitly would still be executed.
264        """
265        _stages = stages or self._stages
266        if not _stages:
267            raise WorkflowError('No stages added')
268        self.set_stages(_stages)
269
270        if not self.dry_run:
271            get_batch().run(wait=wait)
272
273    @staticmethod
274    def _process_first_last_stages(
275        stages: list['Stage'],
276        graph: nx.DiGraph,
277        first_stages: list[str],
278        last_stages: list[str],
279    ):
280        """
281        Apply the first_stages and last_stages config options: skip all stages
282        before first_stages and all stages after last_stages (i.e. their descendants and
283        ancestors, respectively, on the stages DAG, whose edges point from a stage to its dependencies).
284        """
285        stages_d = {s.name: s for s in stages}
286        stage_names = list(stg.name for stg in stages)
287        lower_names = {s.lower() for s in stage_names}
288
289        for param, _stage_list in [
290            ('first_stages', first_stages),
291            ('last_stages', last_stages),
292        ]:
293            for _s_name in _stage_list:
294                if _s_name.lower() not in lower_names:
295                    raise WorkflowError(
296                        f'Value in workflow/{param} "{_s_name}" must be a stage name '
297                        f"or a subset of stages from the available list: "
298                        f'{", ".join(stage_names)}',
299                    )
300
301        if not (last_stages or first_stages):
302            return
303
304        # E.g. if our last_stages is CramQc, MtToEs would still run because it's in
305        # a different branch. So we want to collect all stages after first_stages
306        # and before last_stages in their respective branches, and mark as skipped
307        # everything in other branches.
308        first_stages_keeps: list[str] = first_stages[:]
309        last_stages_keeps: list[str] = last_stages[:]
310
311        for fs in first_stages:
312            for descendant in nx.descendants(graph, fs):
313                if not stages_d[descendant].skipped:
314                    LOGGER.info(
315                        f'Skipping stage {descendant} (precedes {fs} listed in first_stages)',
316                    )
317                    stages_d[descendant].skipped = True
318                for grand_descendant in nx.descendants(graph, descendant):
319                    if not stages_d[grand_descendant].assume_outputs_exist:
320                        LOGGER.info(
321                            f'Not checking expected outputs of not immediately '
322                            f'required stage {grand_descendant} (< {descendant} < {fs})',
323                        )
324                        stages_d[grand_descendant].assume_outputs_exist = True
325
326            for ancestor in nx.ancestors(graph, fs):
327                first_stages_keeps.append(ancestor)
328
329        for ls in last_stages:
330            # ancestors of this last_stage
331            ancestors = nx.ancestors(graph, ls)
332            if any(anc in last_stages for anc in ancestors):
333                # a downstream stage is also in last_stages, so this is not yet
334                # a "real" last stage that we want to run
335                continue
336            for ancestor in ancestors:
337                if stages_d[ancestor].skipped:
338                    continue  # already skipped
339                LOGGER.info(f'Skipping stage {ancestor} (after last {ls})')
340                stages_d[ancestor].skipped = True
341                stages_d[ancestor].assume_outputs_exist = True
342
343            for ancestor in nx.descendants(graph, ls):
344                last_stages_keeps.append(ancestor)
345
346        for _stage in stages:
347            if _stage.name not in last_stages_keeps + first_stages_keeps:
348                _stage.skipped = True
349                _stage.assume_outputs_exist = True
350
351        for stage in stages:
352            if stage.skipped:
353                graph.nodes[stage.name]['skipped'] = True
354
355    @staticmethod
356    def _process_only_stages(
357        stages: list['Stage'],
358        graph: nx.DiGraph,
359        only_stages: list[str],
360    ):
361        if not only_stages:
362            return
363
364        stages_d = {s.name: s for s in stages}
365        stage_names = list(stg.name for stg in stages)
366        lower_names = {s.lower() for s in stage_names}
367
368        for s_name in only_stages:
369            if s_name.lower() not in lower_names:
370                raise WorkflowError(
371                    f'Value in workflow/only_stages "{s_name}" must be a stage '
372                    f"name or a subset of stages from the available list: "
373                    f'{", ".join(stage_names)}',
374                )
375
376        # We want to run stages only appearing in only_stages, and check outputs of
377        # immediate predecessor stages, but skip everything else.
378        required_stages: set[str] = set()
379        for os in only_stages:
380            rs = nx.descendants_at_distance(graph, os, 1)
381            required_stages |= set(rs)
382
383        for stage in stages:
384            # Skip stage not in only_stages, and assume outputs exist...
385            if stage.name not in only_stages:
386                stage.skipped = True
387                stage.assume_outputs_exist = True
388
389        # ...unless stage is directly required by any stage in only_stages
390        for stage_name in required_stages:
391            stages_d[stage_name].assume_outputs_exist = False
392
393    def set_stages(
394        self,
395        requested_stages: list['StageDecorator'],
396    ):
397        """
398        Iterate over stages and call their queue_for_multicohort() methods;
399        this creates all Hail Batch jobs via Stage.queue_jobs().
400        """
401        # TOML options to configure stages:
402        skip_stages = get_config()['workflow'].get('skip_stages', [])
403        only_stages = get_config()['workflow'].get('only_stages', [])
404        first_stages = get_config()['workflow'].get('first_stages', [])
405        last_stages = get_config()['workflow'].get('last_stages', [])
406
407        # Only allow one of only_stages or first_stages/last_stages as they seem
408        # to be mutually exclusive.
409        if only_stages and (first_stages or last_stages or skip_stages):
410            raise WorkflowError(
411                "Workflow config parameter 'only_stages' is incompatible with "
412                + "'first_stages', 'last_stages' and/or 'skip_stages'",
413            )
414
415        LOGGER.info(
416            f'End stages for the workflow "{self.name}": {[cls.__name__ for cls in requested_stages]}',
417        )
418        LOGGER.info('Stages additional configuration:')
419        LOGGER.info(f'  workflow/skip_stages: {skip_stages}')
420        LOGGER.info(f'  workflow/only_stages: {only_stages}')
421        LOGGER.info(f'  workflow/first_stages: {first_stages}')
422        LOGGER.info(f'  workflow/last_stages: {last_stages}')
423
424        # Round 1: initialising stage objects.
425        _stages_d: dict[str, Stage] = {}
426        for cls in requested_stages:
427            if cls.__name__ in _stages_d:
428                continue
429            _stages_d[cls.__name__] = cls()
430
431        # Round 2: depth search to find implicit stages.
432        depth = 0
433        while True:  # might require a few iterations to resolve dependencies recursively
434            depth += 1
435            newly_implicitly_added_d = dict()
436            for stg in _stages_d.values():
437                if stg.name in skip_stages:
438                    stg.skipped = True
439                    continue  # not searching deeper
440
441                if only_stages and stg.name not in only_stages:
442                    stg.skipped = True
443
444                # Iterate dependencies:
445                for reqcls in stg.required_stages_classes:
446                    if reqcls.__name__ in _stages_d:  # already added
447                        continue
448                    # Initialise the required stage and add it as an implicit dependency.
449                    reqstg = reqcls()
450                    newly_implicitly_added_d[reqstg.name] = reqstg
451
452            if newly_implicitly_added_d:
453                LOGGER.info(
454                    f'Additional implicit stages: {list(newly_implicitly_added_d.keys())}',
455                )
456                _stages_d |= newly_implicitly_added_d
457            else:
458                # No new implicit stages added, so can stop the depth-search here
459                break
460
461        # Round 3: set "stage.required_stages" fields to each stage.
462        for stg in _stages_d.values():
463            stg.required_stages = [
464                _stages_d[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in _stages_d
465            ]
466
467        # Round 4: determining order of execution.
468        dag_node2nodes = dict()  # building a DAG
469        for stg in _stages_d.values():
470            dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages)
471        dag = nx.DiGraph(dag_node2nodes)
472
473        try:
474            stage_names = list(reversed(list(nx.topological_sort(dag))))
475        except nx.NetworkXUnfeasible:
476            LOGGER.error('Circular dependencies found between stages')
477            raise
478
479        LOGGER.info(f'Stages in order of execution:\n{stage_names}')
480        stages = [_stages_d[name] for name in stage_names]
481
482        # Set order attribute to stages
483        nx.set_node_attributes(dag, {s.name: num for num, s in enumerate(stages)}, name='order')
484
485        # Update dag with the skipped attribute so it can be updated in self._process_first_last_stages
486        nx.set_node_attributes(dag, False, name='skipped')
487
488        # Round 5: applying workflow options first_stages and last_stages.
489        if first_stages or last_stages:
490            LOGGER.info('Applying workflow/first_stages and workflow/last_stages')
491            self._process_first_last_stages(stages, dag, first_stages, last_stages)
492        elif only_stages:
493            LOGGER.info('Applying workflow/only_stages')
494            self._process_only_stages(stages, dag, only_stages)
495
496        if not (final_set_of_stages := [s.name for s in stages if not s.skipped]):
497            raise WorkflowError('No stages to run')
498
499        LOGGER.info(
500            f'Final list of stages after applying stage configuration options:\n{final_set_of_stages}',
501        )
502
503        required_skipped_stages = [s for s in stages if s.skipped]
504        if required_skipped_stages:
505            LOGGER.info(
506                f'Skipped stages: {", ".join(s.name for s in required_skipped_stages)}',
507            )
508
509        # Round 6: actually adding jobs from the stages.
510        if not self.dry_run:
511            inputs = get_multicohort()  # Would communicate with metamist.
512            for i, stg in enumerate(stages):
513                LOGGER.info('*' * 60)
514                LOGGER.info(f'Stage #{i + 1}: {stg}')
515                # pipeline setup is now done in MultiCohort only
516                # the legacy version (input_datasets) is still supported
517                # that will create a MultiCohort with a single Cohort
518                if isinstance(inputs, MultiCohort):
519                    stg.output_by_target = stg.queue_for_multicohort(inputs)
520                else:
521                    raise WorkflowError(f'Unsupported input type: {inputs}')
522                if errors := self._process_stage_errors(stg.output_by_target):
523                    raise WorkflowError(
524                        f'Stage {stg} failed to queue jobs with errors: ' + '\n'.join(errors),
525                    )
526
527                LOGGER.info('')
528
529        else:
530            self.queued_stages = [stg for stg in _stages_d.values() if not stg.skipped]
531            LOGGER.info(f'Queued stages: {self.queued_stages}')
532
533        # Round 7: show the workflow
534
535        def format_meta(attr: list):
536            return {s: s in attr for s in dag.nodes}
537
538        # First add remaining metadata
539        nx.set_node_attributes(dag, format_meta(skip_stages), name='skip_stages')
540        nx.set_node_attributes(dag, format_meta(only_stages), name='only_stages')
541        nx.set_node_attributes(dag, format_meta(first_stages), name='first_stages')
542        nx.set_node_attributes(dag, format_meta(last_stages), name='last_stages')
543
544        if self.show_workflow:
545            gp = GraphPlot(dag, title='Full Workflow Graph')
546
547            # Remove skipped stages for the simplified graph
548            all_nodes = list(dag.nodes)
549            _ = [dag.remove_node(n) for n in all_nodes if dag.nodes[n]['skipped']]
550            gp2 = GraphPlot(dag, title='Sub-Workflow Graph')
551
552            fig = gp + gp2
553
554            # Show the figure
555            fig.show()
556
557            # If we have a web bucket path
558            try:
559                if web_prefix := self.web_prefix:
560                    html_path = web_prefix / f'{self.name}_workflow.html'
561                    if str(html_path).startswith('gs:/'):
562                        html_file = pio.to_html(fig, full_html=True)
563                        _, file_path = write_to_gcs_bucket(html_file, html_path)
564                        url = URL_BASENAME.format(access_level=self.access_level, name=self.name) + str(file_path)
565
566                        LOGGER.info(f'Link to the graph: {url}')
567                    else:
568                        pio.write_html(fig, file=str(html_path), auto_open=False)
569
570                    LOGGER.info(f'Workflow graph saved to {html_path}')
571
572            except ConnectionError as e:
573                LOGGER.error(f'Failed to save workflow graph: {e}')
574
575    @staticmethod
576    def _process_stage_errors(
577        output_by_target: dict[str, Union['StageOutput', None]],
578    ) -> list[str]:
579        targets_by_error = defaultdict(list)
580        for target, output in output_by_target.items():
581            if output and output.error_msg:
582                targets_by_error[output.error_msg].append(target)
583        return [f'{error}: {", ".join(target_ids)}' for error, target_ids in targets_by_error.items()]
LOGGER = <Logger cpg_flow.workflow (INFO)>
URL_BASENAME = 'https://{access_level}-web.populationgenomics.org.au/{name}/'
def path_walk( expected, collected: set | None = None) -> set[cloudpathlib.cloudpath.CloudPath | pathlib.Path]:

Recursive walk of expected_out: if the object is iterable, walk it. This gets around the issue with nested lists and dicts, mainly arising from the use of Array outputs from Cromwell.

Args:
    expected (Any): any type of object containing Paths
    collected (set): all collected paths so far

Returns:
    a set of all collected Path nodes

Examples:

>>> path_walk({'a': {'b': {'c': Path('d')}}})
{Path('d')}
>>> path_walk({'a': {'b': {'c': [Path('d'), Path('e')]}}})
{Path('d'), Path('e')}
>>> path_walk({'a': Path('b'), 'c': {'d': 'e'}, 'f': Path('g')})
{Path('b'), Path('g')}
class WorkflowError(builtins.Exception):

Error raised by workflow and stage implementation.

class Action(enum.Enum):

Indicates what a stage should do with a specific target.

QUEUE = <Action.QUEUE: 1>
SKIP = <Action.SKIP: 2>
REUSE = <Action.REUSE: 3>
def skip( _fun: Optional[Callable[..., cpg_flow.stage.Stage]] = None, *, reason: str | None = None, assume_outputs_exist: bool = False) -> Union[Callable[..., cpg_flow.stage.Stage], Callable[..., Callable[..., cpg_flow.stage.Stage]]]:

Decorator on top of @stage that sets the self.skipped field to True. By default, expected outputs of a skipped stage will be checked, unless assume_outputs_exist is True.

@skip
@stage
class MyStage1(SequencingGroupStage):
    ...

@skip
@stage(assume_outputs_exist=True)
class MyStage2(SequencingGroupStage):
    ...

def get_workflow(dry_run: bool = False) -> Workflow:
def run_workflow( stages: list[Callable[..., cpg_flow.stage.Stage]] | None = None, wait: bool | None = False, dry_run: bool = False) -> Workflow:
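
A brief usage sketch (the stage class is as in the example near the top of the page; note that `get_workflow()` returns a process-wide singleton, so `dry_run` only takes effect on the first call that creates the `Workflow`):

    from cpg_flow.workflow import get_workflow, run_workflow

    wfl = get_workflow(dry_run=True)                 # first call constructs the singleton Workflow
    wfl = run_workflow(stages=[CramQc], wait=False)  # reuses that instance, queues jobs, submits the batch
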
class Workflow:

Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups. Responsible for orchestrating stages.

Workflow( stages: list[Callable[..., cpg_flow.stage.Stage]] | None = None, dry_run: bool | None = None)
dry_run
show_workflow
access_level
name
run_timestamp: str
status_reporter
queued_stages: list[cpg_flow.stage.Stage]
output_version: str
analysis_prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
tmp_prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
web_prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
def cohort_prefix( self, cohort: cpg_flow.targets.cohort.Cohort, category: str | None = None) -> cloudpathlib.cloudpath.CloudPath | pathlib.Path:

Takes a cohort and category as arguments, and returns a path of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID, e.g. "gs://cpg-project-main/seqr_loader/COH123" or "gs://cpg-project-main-analysis/seqr_loader/COH123".

Args:
    cohort (Cohort): we pull the analysis dataset and id from this Cohort
    category (str | None): sub-bucket for this project

Returns:
    Path
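
For example (hypothetical usage; the resulting path mirrors the docstring above, and `cohort` is assumed to be a `Cohort` from the current `MultiCohort`):

    # e.g. gs://cpg-project-main-analysis/seqr_loader/COH123 for the 'analysis' category
    prefix = get_workflow().cohort_prefix(cohort, category='analysis')
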

def run( self, stages: list[Callable[..., cpg_flow.stage.Stage]] | None = None, wait: bool | None = False):

Resolve stages, add and submit Hail Batch jobs. When run_all_implicit_stages is set, all required stages that were not defined explicitly would still be executed.

def set_stages(self, requested_stages: list[Callable[..., cpg_flow.stage.Stage]]):

Iterate over stages and call their queue_for_multicohort() methods; this creates all Hail Batch jobs via Stage.queue_jobs().
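
To make the dependency resolution concrete, here is a small self-contained sketch of what Round 4 in the source does with networkx. Edges point from a stage to the stages it requires, so a reversed topological sort yields an execution order with dependencies first, while nx.ancestors() gives the downstream stages that the first_stages/last_stages handling prunes. The stage names are illustrative (CramQc and MtToEs appear in the source comments; the others are made up).

    import networkx as nx

    # Edges point from a stage to its required stages, as in Round 4 of set_stages.
    dag = nx.DiGraph({
        'Align': set(),
        'CramQc': {'Align'},
        'AnnotateDataset': {'Align'},
        'MtToEs': {'AnnotateDataset'},
    })

    # Reversed topological order puts dependencies before the stages that need them.
    print(list(reversed(list(nx.topological_sort(dag)))))
    # a valid order, e.g. ['Align', 'AnnotateDataset', 'CramQc', 'MtToEs']

    # Ancestors in this graph are the downstream stages (everything that requires Align):
    print(nx.ancestors(dag, 'Align'))                     # {'CramQc', 'AnnotateDataset', 'MtToEs'}

    # Direct dependencies of a stage, as used by _process_only_stages:
    print(nx.descendants_at_distance(dag, 'MtToEs', 1))   # {'AnnotateDataset'}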