cpg_flow.workflow

Provides a Workflow class and a @stage decorator that allow workflows to be defined in a declarative fashion.
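
For instance, a stage might be declared as in the minimal sketch below, modelled on the skip docstring further down. The import path of stage/SequencingGroupStage and the expected_outputs/queue_jobs hook signatures are assumptions here, not an authoritative API; only Stage.queue_jobs() is named explicitly in this module.

    from cpg_flow.stage import SequencingGroupStage, stage  # assumed import path


    @stage
    class Align(SequencingGroupStage):  # 'Align' is a hypothetical stage name
        def expected_outputs(self, sequencing_group):
            # Declare the files (or Metamist analysis objects) this stage is
            # expected to produce for one sequencing group (assumed hook name).
            ...

        def queue_jobs(self, sequencing_group, inputs):
            # Create Hail Batch jobs for one sequencing group; Stage.queue_jobs()
            # is the hook named in this module's docstrings (signature assumed).
            ...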

A Stage object is responsible for creating Hail Batch jobs and declaring outputs (files or Metamist analysis objects) that are expected to be produced. Each stage acts on a Target, which can be one of the following (see the sketch after this list):

* SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
* Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
* Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
* MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort
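
As a rough sketch of this hierarchy, the targets of the current analysis can be walked from the MultiCohort down. Only get_multicohort() and Dataset.name are used by this module itself; the sequencing-group accessor below is an assumed name.

    from cpg_flow.inputs import get_multicohort

    multicohort = get_multicohort()             # MultiCohort: all SGs in this analysis
    for dataset in multicohort.get_datasets():  # Dataset: SGs grouped by Metamist project
        print(dataset.name)
        for sg in dataset.get_sequencing_groups():  # assumed accessor for SequencingGroups
            print(sg)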

A Workflow object plugs stages together by resolving dependencies between the different target levels. Stages are defined in this package and chained into workflows by their inter-stage dependencies. Workflow names are defined in main.py, which provides a way to choose a workflow using a CLI argument.
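
Concretely, chaining and submitting stages can look like the sketch below. run_workflow() and get_workflow() are defined in this module; the Align stage is the hypothetical one sketched above, and the [workflow] config section (dataset, name, first_stages, last_stages, only_stages, skip_stages, status_reporter, ...) controls how the stage graph is pruned.

    from cpg_flow.workflow import run_workflow

    # Only the terminal stages need to be listed: stages they require are
    # discovered and queued implicitly by Workflow.set_stages().
    workflow = run_workflow(stages=[Align], dry_run=True)

    # With dry_run=True nothing is submitted; the resolved stages are recorded instead.
    print(workflow.queued_stages)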

  1"""
  2Provides a `Workflow` class and a `@stage` decorator that allow workflows to be
  3defined in a declarative fashion.
  4
  5A `Stage` object is responsible for creating Hail Batch jobs and declaring outputs
  6(files or metamist analysis objects) that are expected to be produced. Each stage
  7acts on a `Target`, which can be one of the following:
  8
  9    * SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
 10    * Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
 11    * Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
 12    * MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort
 13
 14A `Workflow` object plugs stages together by resolving dependencies between the different target levels. Stages are
 15defined in this package and chained into workflows by their inter-stage dependencies. Workflow names are defined in
 16main.py, which provides a way to choose a workflow using a CLI argument.
 17"""
 18
 19import functools
 20from collections import defaultdict
 21from collections.abc import Callable
 22from enum import Enum
 23from typing import TYPE_CHECKING, Optional, Union
 24
 25import networkx as nx
 26
 27from cpg_flow.inputs import get_multicohort
 28from cpg_flow.status import MetamistStatusReporter
 29from cpg_flow.targets import Cohort, MultiCohort
 30from cpg_flow.utils import get_logger, slugify, timestamp
 31from cpg_utils import Path
 32from cpg_utils.config import get_config
 33from cpg_utils.hail_batch import get_batch, reset_batch
 34
 35LOGGER = get_logger(__name__)
 36
 37if TYPE_CHECKING:
 38    from cpg_flow.stage import Stage, StageDecorator, StageOutput
 39
 40
 41def path_walk(expected, collected: set | None = None) -> set[Path]:
 42    """
 43    Recursive walk of expected_out.
 44    If the object is iterable, walk it.
 45    This gets around the issue with nested lists and dicts,
 46    mainly arising from the use of Array outputs from Cromwell.
 47
 48    Args:
 49        expected (Any): any type of object containing Paths
 50        collected (set): all collected paths so far
 51
 52    Returns:
 53        a set of all collected Path nodes
 54
 55    Examples:
 56
 57    >>> path_walk({'a': {'b': {'c': Path('d')}}})
 58    {Path('d')}
 59    >>> path_walk({'a': {'b': {'c': [Path('d'), Path('e')]}}})
 60    {Path('d'), Path('e')}
 61    >>> path_walk({'a': Path('b'), 'c': {'d': 'e'}, 'f': Path('g')})
 62    {Path('b'), Path('g')}
 63    """
 64    if collected is None:
 65        collected = set()
 66
 67    if expected is None:
 68        return collected
 69    if isinstance(expected, dict):
 70        for value in expected.values():
 71            collected.update(path_walk(value, collected))
 72    if isinstance(expected, list | set):
 73        for value in expected:
 74            collected.update(path_walk(value, collected))
 75    if isinstance(expected, str):
 76        return collected
 77    if isinstance(expected, Path):
 78        if expected in collected:
 79            raise ValueError(f'Duplicate path {expected} in expected_out')
 80        collected.add(expected)
 81    return collected
 82
 83
 84class WorkflowError(Exception):
 85    """
 86    Error raised by workflow and stage implementation.
 87    """
 88
 89
 90class Action(Enum):
 91    """
 92    Indicates what a stage should do with a specific target.
 93    """
 94
 95    QUEUE = 1
 96    SKIP = 2
 97    REUSE = 3
 98
 99
100# noinspection PyUnusedLocal
101def skip(
102    _fun: Optional['StageDecorator'] = None,
103    *,
104    reason: str | None = None,
105    assume_outputs_exist: bool = False,
106) -> Union['StageDecorator', Callable[..., 'StageDecorator']]:
107    """
108    Decorator on top of `@stage` that sets the `self.skipped` field to True.
109    By default, expected outputs of a skipped stage will be checked,
110    unless `assume_outputs_exist` is True.
111
112    @skip
113    @stage
114    class MyStage1(SequencingGroupStage):
115        ...
116
117    @skip
118    @stage(assume_outputs_exist=True)
119    class MyStage2(SequencingGroupStage):
120        ...
121    """
122
123    def decorator_stage(fun) -> 'StageDecorator':
124        """Implements decorator."""
125
126        @functools.wraps(fun)
127        def wrapper_stage(*args, **kwargs) -> 'Stage':
128            """Decorator helper function."""
129            s = fun(*args, **kwargs)
130            s.skipped = True
131            s.assume_outputs_exist = assume_outputs_exist
132            return s
133
134        return wrapper_stage
135
136    if _fun is None:
137        return decorator_stage
138    return decorator_stage(_fun)
139
140
141_workflow: Optional['Workflow'] = None
142
143
144def get_workflow(dry_run: bool = False) -> 'Workflow':
145    global _workflow
146    if _workflow is None:
147        _workflow = Workflow(dry_run=dry_run)
148    return _workflow
149
150
151def run_workflow(
152    stages: list['StageDecorator'] | None = None,
153    wait: bool | None = False,
154    dry_run: bool = False,
155) -> 'Workflow':
156    wfl = get_workflow(dry_run=dry_run)
157    wfl.run(stages=stages, wait=wait)
158    return wfl
159
160
161class Workflow:
162    """
163    Encapsulates a Hail Batch object, stages, and a cohort of datasets of sequencing groups.
164    Responsible for orchestrating stages.
165    """
166
167    def __init__(
168        self,
169        stages: list['StageDecorator'] | None = None,
170        dry_run: bool | None = None,
171    ):
172        if _workflow is not None:
173            raise ValueError(
174                'Workflow already initialised. Use get_workflow() to get the instance',
175            )
176
177        self.dry_run = dry_run or get_config(True)['workflow'].get('dry_run')
178
179        # TODO: should the ['dataset'] be a get? should we rename it to analysis dataset?
180        analysis_dataset = get_config(True)['workflow']['dataset']
181        name = get_config()['workflow'].get('name', analysis_dataset)
182        description = get_config()['workflow'].get('description', name)
183        self.name = slugify(name)
184
185        self._output_version: str | None = None
186        if output_version := get_config()['workflow'].get('output_version'):
187            self._output_version = slugify(output_version)
188
189        self.run_timestamp: str = get_config()['workflow'].get('run_timestamp') or timestamp()
190
191        # Description
192        if self._output_version:
193            description += f': output_version={self._output_version}'
194        description += f': run_timestamp={self.run_timestamp}'
195        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
196            description += f' [{sequencing_type}]'
197        if not self.dry_run:
198            if ds_set := set(d.name for d in get_multicohort().get_datasets()):
199                description += ' ' + ', '.join(sorted(ds_set))
200            reset_batch()
201            get_batch().name = description
202
203        self.status_reporter = None
204        if get_config()['workflow'].get('status_reporter') == 'metamist':
205            self.status_reporter = MetamistStatusReporter()
206        self._stages: list[StageDecorator] | None = stages
207        self.queued_stages: list[Stage] = []
208
209    @property
210    def output_version(self) -> str:
211        return self._output_version or get_multicohort().alignment_inputs_hash()
212
213    @property
214    def analysis_prefix(self) -> Path:
215        return self._prefix(category='analysis')
216
217    @property
218    def tmp_prefix(self) -> Path:
219        return self._prefix(category='tmp')
220
221    @property
222    def web_prefix(self) -> Path:
223        return self._prefix(category='web')
224
225    @property
226    def prefix(self) -> Path:
227        return self._prefix()
228
229    def _prefix(self, category: str | None = None) -> Path:
230        """
231        Prepare a unique path for the workflow with this name and this input data.
232        """
233        return get_multicohort().analysis_dataset.prefix(category=category) / self.name / self.output_version
234
235    def cohort_prefix(self, cohort: Cohort, category: str | None = None) -> Path:
236        """
237        Takes a cohort and a category as arguments and calls through to the cohort's analysis dataset prefix method.
238        The result is of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID,
239        e.g. "gs://cpg-project-main/seqr_loader/COH123" or "gs://cpg-project-main-analysis/seqr_loader/COH123"
240
241        Args:
242            cohort (Cohort): we pull the analysis dataset and name from this Cohort
243            category (str | None): sub-bucket for this project
244
245        Returns:
246            Path
247        """
248        return cohort.analysis_dataset.prefix(category=category) / self.name / cohort.name
249
250    def run(
251        self,
252        stages: list['StageDecorator'] | None = None,
253        wait: bool | None = False,
254    ):
255        """
256        Resolve stages, add and submit Hail Batch jobs.
257        When `run_all_implicit_stages` is set, all required stages that were not defined
258        explicitly will still be executed.
259        """
260        _stages = stages or self._stages
261        if not _stages:
262            raise WorkflowError('No stages added')
263        self.set_stages(_stages)
264
265        if not self.dry_run:
266            get_batch().run(wait=wait)
267
268    @staticmethod
269    def _process_first_last_stages(
270        stages: list['Stage'],
271        graph: nx.DiGraph,
272        first_stages: list[str],
273        last_stages: list[str],
274    ):
275        """
276        Apply the first_stages and last_stages config options: skip all stages
277        before first_stages, and all stages after last_stages (i.e. descendants and
278        ancestors on the stage DAG).
279        """
280        stages_d = {s.name: s for s in stages}
281        stage_names = list(stg.name for stg in stages)
282        lower_names = {s.lower() for s in stage_names}
283
284        for param, _stage_list in [
285            ('first_stages', first_stages),
286            ('last_stages', last_stages),
287        ]:
288            for _s_name in _stage_list:
289                if _s_name.lower() not in lower_names:
290                    raise WorkflowError(
291                        f'Value in workflow/{param} "{_s_name}" must be a stage name '
292                        f"or a subset of stages from the available list: "
293                        f'{", ".join(stage_names)}',
294                    )
295
296        if not (last_stages or first_stages):
297            return
298
299        # E.g. if our last_stages is CramQc, MtToEs would still run because it's in
300        # a different branch. So we want to collect all stages after first_stages
301        # and before last_stages in their respective branches, and mark everything
302        # in other branches as skipped.
303        first_stages_keeps: list[str] = first_stages[:]
304        last_stages_keeps: list[str] = last_stages[:]
305
306        for fs in first_stages:
307            for descendant in nx.descendants(graph, fs):
308                if not stages_d[descendant].skipped:
309                    LOGGER.info(
310                        f'Skipping stage {descendant} (precedes {fs} listed in first_stages)',
311                    )
312                    stages_d[descendant].skipped = True
313                for grand_descendant in nx.descendants(graph, descendant):
314                    if not stages_d[grand_descendant].assume_outputs_exist:
315                        LOGGER.info(
316                            f'Not checking expected outputs of not immediately '
317                            f'required stage {grand_descendant} (< {descendant} < {fs})',
318                        )
319                        stages_d[grand_descendant].assume_outputs_exist = True
320
321            for ancestor in nx.ancestors(graph, fs):
322                first_stages_keeps.append(ancestor)
323
324        for ls in last_stages:
325            # ancestors of this last_stage
326            ancestors = nx.ancestors(graph, ls)
327            if any(anc in last_stages for anc in ancestors):
328                # a downstream stage is also in last_stages, so this is not yet
329                # a "real" last stage that we want to run
330                continue
331            for ancestor in ancestors:
332                if stages_d[ancestor].skipped:
333                    continue  # already skipped
334                LOGGER.info(f'Skipping stage {ancestor} (after last {ls})')
335                stages_d[ancestor].skipped = True
336                stages_d[ancestor].assume_outputs_exist = True
337
338            for ancestor in nx.descendants(graph, ls):
339                last_stages_keeps.append(ancestor)
340
341        for _stage in stages:
342            if _stage.name not in last_stages_keeps + first_stages_keeps:
343                _stage.skipped = True
344                _stage.assume_outputs_exist = True
345
346    @staticmethod
347    def _process_only_stages(
348        stages: list['Stage'],
349        graph: nx.DiGraph,
350        only_stages: list[str],
351    ):
352        if not only_stages:
353            return
354
355        stages_d = {s.name: s for s in stages}
356        stage_names = list(stg.name for stg in stages)
357        lower_names = {s.lower() for s in stage_names}
358
359        for s_name in only_stages:
360            if s_name.lower() not in lower_names:
361                raise WorkflowError(
362                    f'Value in workflow/only_stages "{s_name}" must be a stage '
363                    f"name or a subset of stages from the available list: "
364                    f'{", ".join(stage_names)}',
365                )
366
367        # We want to run only the stages appearing in only_stages, and check outputs of
368        # immediate predecessor stages, but skip everything else.
369        required_stages: set[str] = set()
370        for os in only_stages:
371            rs = nx.descendants_at_distance(graph, os, 1)
372            required_stages |= set(rs)
373
374        for stage in stages:
375            # Skip stage not in only_stages, and assume outputs exist...
376            if stage.name not in only_stages:
377                stage.skipped = True
378                stage.assume_outputs_exist = True
379
380        # ...unless stage is directly required by any stage in only_stages
381        for stage_name in required_stages:
382            stages_d[stage_name].assume_outputs_exist = False
383
384    def set_stages(
385        self,
386        requested_stages: list['StageDecorator'],
387    ):
388        """
389        Iterate over stages and call their queue_for_multicohort() methods;
390        through that, create all Hail Batch jobs via Stage.queue_jobs().
391        """
392        # TOML options to configure stages:
393        skip_stages = get_config()['workflow'].get('skip_stages', [])
394        only_stages = get_config()['workflow'].get('only_stages', [])
395        first_stages = get_config()['workflow'].get('first_stages', [])
396        last_stages = get_config()['workflow'].get('last_stages', [])
397
398        # Only allow one of only_stages or first_stages/last_stages/skip_stages,
399        # as they are mutually exclusive.
400        if only_stages and (first_stages or last_stages or skip_stages):
401            raise WorkflowError(
402                "Workflow config parameter 'only_stages' is incompatible with "
403                + "'first_stages', 'last_stages' and/or 'skip_stages'",
404            )
405
406        LOGGER.info(
407            f'End stages for the workflow "{self.name}": {[cls.__name__ for cls in requested_stages]}',
408        )
409        LOGGER.info('Stages additional configuration:')
410        LOGGER.info(f'  workflow/skip_stages: {skip_stages}')
411        LOGGER.info(f'  workflow/only_stages: {only_stages}')
412        LOGGER.info(f'  workflow/first_stages: {first_stages}')
413        LOGGER.info(f'  workflow/last_stages: {last_stages}')
414
415        # Round 1: initialising stage objects.
416        _stages_d: dict[str, Stage] = {}
417        for cls in requested_stages:
418            if cls.__name__ in _stages_d:
419                continue
420            _stages_d[cls.__name__] = cls()
421
422        # Round 2: depth search to find implicit stages.
423        depth = 0
424        while True:  # might require a few iterations to resolve dependencies recursively
425            depth += 1
426            newly_implicitly_added_d = dict()
427            for stg in _stages_d.values():
428                if stg.name in skip_stages:
429                    stg.skipped = True
430                    continue  # not searching deeper
431
432                if only_stages and stg.name not in only_stages:
433                    stg.skipped = True
434
435                # Iterate dependencies:
436                for reqcls in stg.required_stages_classes:
437                    if reqcls.__name__ in _stages_d:  # already added
438                        continue
439                    # Initialising and adding as an implicit stage.
440                    reqstg = reqcls()
441                    newly_implicitly_added_d[reqstg.name] = reqstg
442
443            if newly_implicitly_added_d:
444                LOGGER.info(
445                    f'Additional implicit stages: {list(newly_implicitly_added_d.keys())}',
446                )
447                _stages_d |= newly_implicitly_added_d
448            else:
449                # No new implicit stages added, so can stop the depth-search here
450                break
451
452        # Round 3: set "stage.required_stages" fields to each stage.
453        for stg in _stages_d.values():
454            stg.required_stages = [
455                _stages_d[cls.__name__] for cls in stg.required_stages_classes if cls.__name__ in _stages_d
456            ]
457
458        # Round 4: determining order of execution.
459        dag_node2nodes = dict()  # building a DAG
460        for stg in _stages_d.values():
461            dag_node2nodes[stg.name] = set(dep.name for dep in stg.required_stages)
462        dag = nx.DiGraph(dag_node2nodes)
463        try:
464            stage_names = list(reversed(list(nx.topological_sort(dag))))
465        except nx.NetworkXUnfeasible:
466            LOGGER.error('Circular dependencies found between stages')
467            raise
468
469        LOGGER.info(f'Stages in order of execution:\n{stage_names}')
470        stages = [_stages_d[name] for name in stage_names]
471
472        # Round 5: applying workflow options first_stages and last_stages.
473        if first_stages or last_stages:
474            LOGGER.info('Applying workflow/first_stages and workflow/last_stages')
475            self._process_first_last_stages(stages, dag, first_stages, last_stages)
476        elif only_stages:
477            LOGGER.info('Applying workflow/only_stages')
478            self._process_only_stages(stages, dag, only_stages)
479
480        if not (final_set_of_stages := [s.name for s in stages if not s.skipped]):
481            raise WorkflowError('No stages to run')
482
483        LOGGER.info(
484            f'Final list of stages after applying stage configuration options:\n{final_set_of_stages}',
485        )
486
487        required_skipped_stages = [s for s in stages if s.skipped]
488        if required_skipped_stages:
489            LOGGER.info(
490                f'Skipped stages: {", ".join(s.name for s in required_skipped_stages)}',
491            )
492
493        # Round 6: actually adding jobs from the stages.
494        if not self.dry_run:
495            inputs = get_multicohort()  # Communicates with Metamist.
496            for i, stg in enumerate(stages):
497                LOGGER.info('*' * 60)
498                LOGGER.info(f'Stage #{i + 1}: {stg}')
499                # pipeline setup is now done in MultiCohort only
500                # the legacy version (input_datasets) is still supported
501                # that will create a MultiCohort with a single Cohort
502                if isinstance(inputs, MultiCohort):
503                    stg.output_by_target = stg.queue_for_multicohort(inputs)
504                else:
505                    raise WorkflowError(f'Unsupported input type: {inputs}')
506                if errors := self._process_stage_errors(stg.output_by_target):
507                    raise WorkflowError(
508                        f'Stage {stg} failed to queue jobs with errors: ' + '\n'.join(errors),
509                    )
510
511                LOGGER.info('')
512
513        else:
514            self.queued_stages = [stg for stg in _stages_d.values() if not stg.skipped]
515            LOGGER.info(f'Queued stages: {self.queued_stages}')
516
517    @staticmethod
518    def _process_stage_errors(
519        output_by_target: dict[str, Union['StageOutput', None]],
520    ) -> list[str]:
521        targets_by_error = defaultdict(list)
522        for target, output in output_by_target.items():
523            if output and output.error_msg:
524                targets_by_error[output.error_msg].append(target)
525        return [f'{error}: {", ".join(target_ids)}' for error, target_ids in targets_by_error.items()]