cpg_flow.stage

Provides a Workflow class and a @stage decorator that allow you to define workflows in a declarative fashion.

A Stage object is responsible for creating Hail Batch jobs and declaring outputs (files or metamist analysis objects) that are expected to be produced. Each stage acts on a Target, which can be one of the following:

* SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
* Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
* Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
* MultiCohort - the union of all SGs in this analysis, across all Metamist CustomCohorts

A Workflow object plugs stages together by resolving dependencies between different levels. Stages are defined in this package and chained into Workflows by their inter-stage dependencies. Workflow names are defined in main.py, which provides a way to choose a workflow using a CLI argument.
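For example, a minimal sequencing-group-level stage might look like this (a sketch, not a definitive implementation; the flag-file naming and the trivial job command are illustrative):

    from cpg_flow.stage import SequencingGroupStage, StageInput, StageOutput, stage
    from cpg_flow.targets import SequencingGroup
    from cpg_utils.hail_batch import get_batch

    @stage
    class MyStage(SequencingGroupStage):
        def expected_outputs(self, sequencing_group: SequencingGroup):
            # One flag file per sequencing group, under this stage's prefix.
            return self.prefix / f'{sequencing_group.id}.done'

        def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> StageOutput | None:
            j = get_batch().new_job('MyStage', self.get_job_attrs(sequencing_group))
            j.command('echo done')
            return self.make_outputs(sequencing_group, data=self.expected_outputs(sequencing_group), jobs=j)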

  1"""
  2Provides a `Workflow` class and a `@stage` decorator that allow to define workflows
  3in a declarative fashion.
  4
  5A `Stage` object is responsible for creating Hail Batch jobs and declaring outputs
  6(files or metamist analysis objects) that are expected to be produced. Each stage
  7acts on a `Target`, which can be of the following:
  8
  9    * SequencingGroup - an individual Sequencing Group (e.g. the CRAM of a single sample)
 10    * Dataset - a stratification of SGs in this analysis by Metamist Project (e.g. all SGs in acute-care)
 11    * Cohort - a stratification of SGs in this analysis by Metamist CustomCohort
 12    * MultiCohort - a union of all SGs in this analysis by Metamist CustomCohort
 13
 14A `Workflow` object plugs stages together by resolving dependencies between different levels accordingly. Stages are
 15defined in this package, and chained into Workflows by their inter-Stages dependencies. Workflow names are defined in
 16main.py, which provides a way to choose a workflow using a CLI argument.
 17"""

import functools
import pathlib
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Generic, Optional, TypeVar, cast

from cloudpathlib import CloudPath

from hailtop.batch.job import Job

from cpg_flow.targets import Cohort, Dataset, MultiCohort, SequencingGroup, Target
from cpg_flow.utils import ExpectedResultT, exists, get_logger
from cpg_flow.workflow import Action, WorkflowError, get_workflow, path_walk
from cpg_utils import Path
from cpg_utils.config import get_config
from cpg_utils.hail_batch import get_batch

LOGGER = get_logger(__name__)

StageDecorator = Callable[..., 'Stage']

# Type variable to use with Generic to make sure a Stage subclass always matches the
# corresponding Target subclass. We can't just use the Target superclass because
# it would violate the Liskov substitution principle (i.e. any Stage subclass would
# have to be able to work on any Target subclass).
TargetT = TypeVar('TargetT', bound=Target)


class StageInputNotFoundError(Exception):
    """
    Thrown when a stage requests input from another stage
    that doesn't exist.
    """


# noinspection PyShadowingNames
class StageOutput:
    """
    Represents a result of a specific stage, which was run on a specific target.
    Can be a file path, or a Hail Batch Resource. Optionally wrapped in a dict.
    """

    def __init__(
        self,
        target: Target,
        data: ExpectedResultT = None,
        jobs: Sequence[Job | None] | Job | None = None,
        meta: dict | None = None,
        reusable: bool = False,
        skipped: bool = False,
        error_msg: str | None = None,
        stage: Optional['Stage'] = None,
    ):
        # Raw output data: file path(s) or a Hail Batch Resource, optionally wrapped in a dict.
        self.data = data
        self.stage = stage
        self.target = target
        _jobs = [jobs] if isinstance(jobs, Job) else (jobs or [])
        self.jobs: list[Job] = [j for j in _jobs if j is not None]
        self.meta: dict = meta or {}
        self.reusable = reusable
        self.skipped = skipped
        self.error_msg = error_msg

    def __repr__(self) -> str:
        res = (
            f'StageOutput({self.data}'
            f' target={self.target}'
            f' stage={self.stage}'
            + (' [reusable]' if self.reusable else '')
            + (' [skipped]' if self.skipped else '')
            + (f' [error: {self.error_msg}]' if self.error_msg else '')
            + f' meta={self.meta}'
            + ')'
        )
        return res

    def _get(self, key=None) -> str | Path:
        if self.data is None:
            raise ValueError(f'{self.stage}: output data is not available')

        if key is not None:
            if not isinstance(self.data, dict):
                raise ValueError(
                    f'{self.stage}: {self.data} is not a dictionary, can\'t get "{key}"',
                )
            res = cast(dict, self.data)[key]
        else:
            res = self.data
        return res

    def as_str(self, key=None) -> str:
        """
        Cast the result to a simple string. Throws an exception when the cast fails.
        `key` is used to extract the value when the result is a dictionary.
        """
        res = self._get(key)
        if not isinstance(res, str):
            raise ValueError(f'{res} is not a str.')
        return cast(str, res)

    def as_path(self, key=None) -> Path:
        """
        Cast the result to a path object. Throws an exception when the cast fails.
        `key` is used to extract the value when the result is a dictionary.
        """
        res = self._get(key)
        if not isinstance(res, CloudPath | pathlib.Path):
            raise ValueError(f'{res} is not a path object.')

        return cast(Path, res)

    def as_dict(self) -> dict[str, Path]:
        """
        Cast the result to a dictionary, or throw an error if the result is not a dictionary.
        """
        if not isinstance(self.data, dict):
            raise ValueError(f'{self.data} is not a dictionary.')
        return self.data

139
140# noinspection PyShadowingNames
141class StageInput:
142    """
143    Represents an input for a stage run. It wraps the outputs of all required upstream
144    stages for corresponding targets (e.g. all GVCFs from a GenotypeSample stage
145    for a JointCalling stage, along with Hail Batch jobs).
146
147    An object of this class is passed to the public `queue_jobs` method of a Stage,
148    and can be used to query dependency files and jobs.
149    """
150
151    def __init__(self, stage: 'Stage'):
152        self.stage = stage
153        self._outputs_by_target_by_stage: dict[str, dict[str, StageOutput | None]] = {}
154
155    def add_other_stage_output(self, output: StageOutput):
156        """
157        Add output from another stage run.
158        """
159        assert output.stage is not None, output
160        if not output.target.active:
161            return
162        if not output.target.get_sequencing_groups():
163            return
164        if not output.data and not output.jobs:
165            return
166        stage_name = output.stage.name
167        target_id = output.target.target_id
168        if stage_name not in self._outputs_by_target_by_stage:
169            self._outputs_by_target_by_stage[stage_name] = dict()
170        self._outputs_by_target_by_stage[stage_name][target_id] = output
171
172    def _each(
173        self,
174        fun: Callable,
175        stage: StageDecorator,
176    ):
177        if stage.__name__ not in [s.name for s in self.stage.required_stages]:
178            raise WorkflowError(
179                f'{self.stage.name}: getting inputs from stage {stage.__name__}, '
180                f'but {stage.__name__} is not listed in required_stages. '
181                f'Consider adding it into the decorator: '
182                f'@stage(required_stages=[{stage.__name__}])',
183            )
184
185        if stage.__name__ not in self._outputs_by_target_by_stage:
186            raise WorkflowError(
187                f'No inputs from {stage.__name__} for {self.stage.name} found '
188                + 'after skipping targets with missing inputs. '
189                + (
190                    'Check the logs if all sequencing groups were missing inputs from previous '
191                    'stages, and consider changing `workflow/first_stage`'
192                    if get_config()['workflow'].get('skip_sgs_with_missing_input')
193                    else ''
194                ),
195            )
196
197        return {
198            trg: fun(result)
199            for trg, result in self._outputs_by_target_by_stage.get(
200                stage.__name__,
201                {},
202            ).items()
203        }
204
205    def as_path_by_target(
206        self,
207        stage: StageDecorator,
208        key: str | None = None,
209    ) -> dict[str, Path]:
210        """
211        Get a single file path result, indexed by target for a specific stage
212        """
213        return self._each(fun=(lambda r: r.as_path(key=key)), stage=stage)
214
215    def as_dict_by_target(self, stage: StageDecorator) -> dict[str, dict[str, Path]]:
216        """
217        Get as a dict of files/resources for a specific stage, indexed by target
218        """
219        return self._each(fun=(lambda r: r.as_dict()), stage=stage)
220
    def as_path_dict_by_target(
        self,
        stage: StageDecorator,
    ) -> dict[str, dict[str, Path]]:
        """
        Get a dict of paths for a specific stage, indexed by target
        """
        return self._each(fun=(lambda r: r.as_dict()), stage=stage)

    def _get(
        self,
        target: Target,
        stage: StageDecorator,
    ):
        if not self._outputs_by_target_by_stage.get(stage.__name__):
            raise StageInputNotFoundError(
                f'No output found from stage {stage.__name__}, required for stage '
                f'{self.stage.name}. Is {stage.__name__} in the `required_stages` '
                f'decorator? Available: {self._outputs_by_target_by_stage}',
            )
        if not self._outputs_by_target_by_stage[stage.__name__].get(target.target_id):
            raise StageInputNotFoundError(
                f'No output found for {target} from stage {stage.__name__}, required for stage {self.stage.name}',
            )
        return self._outputs_by_target_by_stage[stage.__name__][target.target_id]

    def as_path(
        self,
        target: Target,
        stage: StageDecorator,
        key: str | None = None,
    ) -> Path:
        """
        Represent as a path to a file, otherwise fail.
        `stage` can be a callable, or a subclass of Stage
        """
        res = self._get(target=target, stage=stage)
        return res.as_path(key)

    def as_str(
        self,
        target: Target,
        stage: StageDecorator,
        key: str | None = None,
    ) -> str:
        """
        Represent as a simple string, otherwise fail.
        `stage` can be a callable, or a subclass of Stage
        """
        res = self._get(target=target, stage=stage)
        return res.as_str(key)

    def as_dict(self, target: Target, stage: StageDecorator) -> dict[str, Path]:
        """
        Get a dict of paths for a specific target and stage
        """
        res = self._get(target=target, stage=stage)
        return res.as_dict()

    def get_jobs(self, target: Target) -> list[Job]:
        """
        Get the list of jobs that the next stage would depend on.
        """
        all_jobs: list[Job] = []
        target_sequencing_groups = target.get_sequencing_group_ids()
        for stage_, outputs_by_target in self._outputs_by_target_by_stage.items():
            for target_, output in outputs_by_target.items():
                if output:
                    output_sequencing_groups = output.target.get_sequencing_group_ids()
                    sequencing_groups_intersect = set(target_sequencing_groups) & set(
                        output_sequencing_groups,
                    )
                    if sequencing_groups_intersect:
                        for j in output.jobs:
                            assert j, f'Stage: {stage_}, target: {target_}, output: {output}'
                        all_jobs.extend(output.jobs)
        return all_jobs

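# Example (sketch): inside `queue_jobs()` of a downstream stage, the StageInput can
# be queried for upstream results (`Align` is a hypothetical upstream stage listed
# in the downstream stage's `required_stages`):
#
#     def queue_jobs(self, sequencing_group, inputs: StageInput) -> StageOutput | None:
#         cram_path = inputs.as_path(sequencing_group, Align, key='cram')
#         jobs = inputs.get_jobs(sequencing_group)  # upstream jobs to depend on
#         ...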

class Stage(Generic[TargetT], ABC):
    """
    Abstract class for a workflow stage. Parametrised by a specific Target subclass,
    i.e. SequencingGroupStage(Stage[SequencingGroup]) should only be able to work on SequencingGroup(Target).
    """

    def __init__(
        self,
        name: str,
        required_stages: list[StageDecorator] | StageDecorator | None = None,
        analysis_type: str | None = None,
        analysis_keys: list[str] | None = None,
        update_analysis_meta: Callable[[str], dict] | None = None,
        tolerate_missing_output: bool = False,
        skipped: bool = False,
        assume_outputs_exist: bool = False,
        forced: bool = False,
    ):
        self._name = name
        self.required_stages_classes: list[StageDecorator] = []
        if required_stages:
            if isinstance(required_stages, list):
                self.required_stages_classes.extend(required_stages)
            else:
                self.required_stages_classes.append(required_stages)

        # Dependencies. Populated in workflow.run(), after we know all stages.
        self.required_stages: list[Stage] = []

        self.status_reporter = get_workflow().status_reporter
        # If `analysis_type` is defined, it will be used to create/update Analysis
        # entries in Metamist.
        self.analysis_type = analysis_type
        # If `analysis_keys` is defined, it will be used to extract the value for
        # `Analysis.output` if Stage.expected_outputs() returns a dict.
        self.analysis_keys = analysis_keys
        # If `update_analysis_meta` is defined, it is called on the `Analysis.output`
        # field, and the result is merged into the `Analysis.meta` dictionary.
        self.update_analysis_meta = update_analysis_meta

        self.tolerate_missing_output = tolerate_missing_output

        # Populated with the return value of `queue_for_multicohort()`
        self.output_by_target: dict[str, StageOutput | None] = dict()

        self.skipped = skipped
        self.forced = forced or self.name in get_config()['workflow'].get(
            'force_stages',
            [],
        )
        self.assume_outputs_exist = assume_outputs_exist

    @property
    def tmp_prefix(self) -> Path:
        return get_workflow().tmp_prefix / self.name

    @property
    def web_prefix(self) -> Path:
        return get_workflow().web_prefix / self.name

    @property
    def prefix(self) -> Path:
        return get_workflow().prefix / self.name

    @property
    def analysis_prefix(self) -> Path:
        return get_workflow().analysis_prefix / self.name

    def get_stage_cohort_prefix(
        self,
        cohort: Cohort,
        category: str | None = None,
    ) -> Path:
        """
        Takes a cohort as an argument and calls through to the Workflow cohort_prefix method.
        The result is of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID / STAGE_NAME,
        e.g. "gs://cpg-project-main/seqr_loader/COH123/MyStage"

        Args:
            cohort (Cohort): we pull the analysis dataset and name from this Cohort
            category (str | None): main, tmp, test, analysis, web

        Returns:
            Path
        """
        return get_workflow().cohort_prefix(cohort, category=category) / self.name

    def __str__(self):
        res = f'{self._name}'
        if self.skipped:
            res += ' [skipped]'
        if self.forced:
            res += ' [forced]'
        if self.assume_outputs_exist:
            res += ' [assume_outputs_exist]'
        if self.required_stages:
            res += f' <- [{", ".join([s.name for s in self.required_stages])}]'
        return res

    @property
    def name(self) -> str:
        """
        Stage name (unique and descriptive)
        """
        return self._name

    @abstractmethod
    def queue_jobs(self, target: TargetT, inputs: StageInput) -> StageOutput | None:
        """
        Adds Hail Batch jobs that process `target`.
        Assumes that all the housekeeping work is done: checking for missing inputs
        from required stages, and checking for possible reuse of existing outputs.
        """

    @abstractmethod
    def expected_outputs(self, target: TargetT) -> ExpectedResultT:
        """
        Get path(s) to files that the stage is expected to generate for a `target`.
        Used within `queue_jobs()` to pass paths to outputs to job commands,
        as well as by the workflow to check if the stage's expected outputs already
        exist and can be reused.

        Can be a str, a Path object, or a dictionary of str/Path objects.
        """

    # TODO: remove this method
    def deprecated_queue_for_cohort(
        self,
        cohort: Cohort,
    ) -> dict[str, StageOutput | None]:
        """
        Queues jobs for each corresponding target, defined by the Stage subclass.
        Returns a dictionary of `StageOutput` objects indexed by target unique_id.
        Unused; ready for deletion.
        """
        return {}

    @abstractmethod
    def queue_for_multicohort(
        self,
        multicohort: MultiCohort,
    ) -> dict[str, StageOutput | None]:
        """
        Queues jobs for each corresponding target, defined by the Stage subclass.

        Returns a dictionary of `StageOutput` objects indexed by target unique_id.
        """

    def _make_inputs(self) -> StageInput:
        """
        Collects outputs from all dependencies and creates the input for this stage
        """
        inputs = StageInput(self)
        for prev_stage in self.required_stages:
            for _, stage_output in prev_stage.output_by_target.items():
                if stage_output:
                    inputs.add_other_stage_output(stage_output)
        return inputs

    def make_outputs(
        self,
        target: Target,
        data: ExpectedResultT = None,  # TODO: ExpectedResultT is probably too broad, our code only really supports dicts
        jobs: Sequence[Job | None] | Job | None = None,
        meta: dict | None = None,
        reusable: bool = False,
        skipped: bool = False,
        error_msg: str | None = None,
    ) -> StageOutput:
        """
        Create a StageOutput for this stage.
        """
        return StageOutput(
            target=target,
            data=data,
            jobs=jobs,
            meta=meta,
            reusable=reusable,
            skipped=skipped,
            error_msg=error_msg,
            stage=self,
        )

    def _queue_jobs_with_checks(
        self,
        target: TargetT,
        action: Action | None = None,
    ) -> StageOutput | None:
        """
        Checks what to do with the target, and either queues jobs, or skips/reuses results.
        """
        if not action:
            action = self._get_action(target)

        inputs = self._make_inputs()
        expected_out = self.expected_outputs(target)

        if action == Action.QUEUE:
            outputs = self.queue_jobs(target, inputs)
        elif action == Action.REUSE:
            outputs = self.make_outputs(
                target=target,
                data=expected_out,
                reusable=True,
            )
        else:  # Action.SKIP
            outputs = None

        if not outputs:
            return None

        outputs.stage = self
        outputs.meta |= self.get_job_attrs(target)

        for output_job in outputs.jobs:
            if output_job:
                for input_job in inputs.get_jobs(target):
                    assert input_job, f'Input dependency job for stage: {self}, target: {target}'
                    output_job.depends_on(input_job)

        if outputs.error_msg:
            return outputs

        # Adding status reporter jobs
        if self.analysis_type and self.status_reporter and action == Action.QUEUE and outputs.data:
            analysis_outputs: list[str | Path] = []
            if isinstance(outputs.data, dict):
                if not self.analysis_keys:
                    raise WorkflowError(
                        f'Cannot create Analysis: `analysis_keys` '
                        f'must be set with the @stage decorator to select a value from '
                        f'the expected_outputs dict: {outputs.data}',
                    )
                if not all(key in outputs.data for key in self.analysis_keys):
                    raise WorkflowError(
                        f'Cannot create Analysis for stage {self.name}: `analysis_keys` '
                        f'"{self.analysis_keys}" is not a subset of the expected_outputs '
                        f'keys {outputs.data.keys()}',
                    )

                for analysis_key in self.analysis_keys:
                    analysis_outputs.append(outputs.data[analysis_key])

            else:
                analysis_outputs.append(outputs.data)

            project_name = None
            if isinstance(target, SequencingGroup):
                project_name = target.dataset.name
            elif isinstance(target, Dataset):
                project_name = target.name
            elif isinstance(target, Cohort | MultiCohort):
                project_name = target.analysis_dataset.name

            assert isinstance(project_name, str)

            # Bump the name to include `-test`
            if get_config()['workflow']['access_level'] == 'test' and 'test' not in project_name:
                project_name = f'{project_name}-test'

            for analysis_output in analysis_outputs:
                if not outputs.jobs:
                    continue

                assert isinstance(
                    analysis_output,
                    str | Path,
                ), f'{analysis_output} should be a str or Path object'
                if outputs.meta is None:
                    outputs.meta = {}

                self.status_reporter.create_analysis(
                    b=get_batch(),
                    output=str(analysis_output),
                    analysis_type=self.analysis_type,
                    target=target,
                    jobs=outputs.jobs,
                    job_attr=self.get_job_attrs(target) | {'stage': self.name, 'tool': 'metamist'},
                    meta=outputs.meta,
                    update_analysis_meta=self.update_analysis_meta,
                    tolerate_missing_output=self.tolerate_missing_output,
                    project_name=project_name,
                )

        return outputs

    def _get_action(self, target: TargetT) -> Action:
        """
        Based on stage parameters and the existence of expected outputs, determines
        what to do with the target: queue, skip, or reuse.
        """
        if target.forced and not self.skipped:
            LOGGER.info(f'{self.name}: {target} [QUEUE] (target is forced)')
            return Action.QUEUE

        if (d := get_config()['workflow'].get('skip_stages_for_sgs')) and self.name in d:
            skip_targets = d[self.name]
            if target.target_id in skip_targets:
                LOGGER.info(
                    f'{self.name}: {target} [SKIP] (is in workflow/skip_stages_for_sgs)',
                )
                return Action.SKIP

        expected_out = self.expected_outputs(target)
        reusable, first_missing_path = self._is_reusable(expected_out)

        if self.skipped:
            if reusable and not first_missing_path:
                LOGGER.debug(
                    f'{self.name}: {target} [REUSE] (stage skipped, and outputs exist)',
                )
                return Action.REUSE
            if get_config()['workflow'].get('skip_sgs_with_missing_input'):
                LOGGER.warning(
                    f'{self.name}: {target} [SKIP] (stage is required, '
                    f'but is marked as "skipped", '
                    f'workflow/skip_sgs_with_missing_input=true, '
                    f'and some expected outputs for the target do not exist: '
                    f'{first_missing_path})',
                )
                # `workflow/skip_sgs_with_missing_input` means that we can ignore
                # sgs/datasets that have missing results from skipped stages.
                # This is our case, so indicating that this sg/dataset should
                # be ignored:
                target.active = False
                return Action.SKIP
            if self.name in get_config()['workflow'].get(
                'allow_missing_outputs_for_stages',
                [],
            ):
                LOGGER.info(
                    f'{self.name}: {target} [REUSE] (stage is skipped, some outputs are '
                    f'missing, but stage is listed in '
                    f'workflow/allow_missing_outputs_for_stages)',
                )
                return Action.REUSE
            raise WorkflowError(
                f'{self.name}: stage is required, but is skipped, and '
                f'the following expected outputs for target {target} do not exist: '
                f'{first_missing_path}',
            )

        if reusable and not first_missing_path:
            if target.forced:
                LOGGER.info(
                    f'{self.name}: {target} [QUEUE] (can reuse, but forcing the target to rerun this stage)',
                )
                return Action.QUEUE
            if self.forced:
                LOGGER.info(
                    f'{self.name}: {target} [QUEUE] (can reuse, but forcing the stage to rerun)',
                )
                return Action.QUEUE
            LOGGER.info(
                f'{self.name}: {target} [REUSE] (expected outputs exist: {expected_out})',
            )
            return Action.REUSE

        LOGGER.info(f'{self.name}: {target} [QUEUE]')

        return Action.QUEUE

    def _is_reusable(self, expected_out: ExpectedResultT) -> tuple[bool, Path | None]:
        """
        Checks if the expected outputs of this stage already exist and can be reused.
        Args:
            expected_out (ExpectedResultT): expected outputs of a stage

        Returns:
            tuple[bool, Path | None]:
                bool: True if the outputs can be reused, False otherwise
                Path | None: first missing path, if any
        """
        if self.assume_outputs_exist:
            LOGGER.debug(f'Assuming outputs exist. Expected output is {expected_out}')
            return True, None

        if not expected_out:
            # Marking as reusable. If a stage does not naturally produce any outputs,
            # it would still need to create some flag file.
            LOGGER.debug('No expected outputs, assuming outputs exist')
            return True, None

        if get_config()['workflow'].get('check_expected_outputs'):
            paths = path_walk(expected_out)
            LOGGER.info(
                f'Checking if {paths} from expected output {expected_out} exist',
            )
            if not paths:
                LOGGER.info(f'{expected_out} is not reusable. No paths found.')
                return False, None

            if first_missing_path := next((p for p in paths if not exists(p)), None):
                LOGGER.info(
                    f'{expected_out} is not reusable, {first_missing_path} is missing',
                )
                return False, first_missing_path

            return True, None
        if self.skipped:
            # Do not check the files' existence; trust that they exist.
            # Note that for skipped stages, we automatically assume outputs exist.
            return True, None
        # Do not check the files' existence; assume they don't exist:
        return False, None

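    # In summary, `_get_action()` together with `_is_reusable()` implements this
    # decision matrix (a sketch; config keys live under the `workflow` section):
    #
    #     outputs exist + stage or target forced  -> QUEUE (rerun anyway)
    #     outputs exist                           -> REUSE
    #     outputs missing + stage skipped         -> SKIP (if skip_sgs_with_missing_input)
    #                                                or REUSE (if allow_missing_outputs_for_stages)
    #                                                or a WorkflowError
    #     outputs missing                         -> QUEUE
    #
    # When `workflow/check_expected_outputs` is unset, existence is not checked:
    # outputs are assumed missing for active stages and present for skipped ones.
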
    def get_job_attrs(self, target: TargetT | None = None) -> dict[str, str]:
        """
        Create a Hail Batch Job attributes dictionary
        """
        job_attrs = dict(stage=self.name)
        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
            job_attrs['sequencing_type'] = sequencing_type
        if target:
            job_attrs |= target.get_job_attrs()
        return job_attrs


def stage(
    cls: type['Stage'] | None = None,
    *,
    analysis_type: str | None = None,
    analysis_keys: list[str | Path] | None = None,
    update_analysis_meta: Callable[[str], dict] | None = None,
    tolerate_missing_output: bool = False,
    required_stages: list[StageDecorator] | StageDecorator | None = None,
    skipped: bool = False,
    assume_outputs_exist: bool = False,
    forced: bool = False,
) -> StageDecorator | Callable[..., StageDecorator]:
    """
    Implements a standard class decorator pattern with optional arguments.
    The goal is to allow declaring workflow stages without having to implement
    a constructor method. E.g.

    @stage(required_stages=[Align])
    class GenotypeSample(SequencingGroupStage):
        def expected_outputs(self, sequencing_group: SequencingGroup):
            ...
        def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> StageOutput:
            ...

    @analysis_type: if defined, will be used to create/update `Analysis` entries
        using the status reporter.
    @analysis_keys: if defined, will be used to extract the value for `Analysis.output`
        if Stage.expected_outputs() returns a dict.
    @update_analysis_meta: if defined, this function is called on the `Analysis.output`
        field, and returns a dictionary to be merged into `Analysis.meta`.
    @tolerate_missing_output: if True, when registering the output of this stage,
        allow for the output file to be missing (only relevant for the metamist entry).
    @required_stages: list of other stage classes that are required prerequisites
        for this stage. Outputs of those stages will be passed to
        `Stage.queue_jobs(... , inputs)` as `inputs`, and all required
        dependencies between Hail Batch jobs will be set automatically as well.
    @skipped: always skip this stage.
    @assume_outputs_exist: assume expected outputs of this stage always exist.
    @forced: always force run this stage, regardless of the outputs' existence.
    """

    def decorator_stage(_cls) -> StageDecorator:
        """Implements the decorator."""

        @functools.wraps(_cls)
        def wrapper_stage() -> Stage:
            """Decorator helper function."""
            return _cls(
                name=_cls.__name__,
                required_stages=required_stages,
                analysis_type=analysis_type,
                analysis_keys=analysis_keys,
                update_analysis_meta=update_analysis_meta,
                skipped=skipped,
                assume_outputs_exist=assume_outputs_exist,
                forced=forced,
                tolerate_missing_output=tolerate_missing_output,
            )

        return wrapper_stage

    if cls is None:
        return decorator_stage
    return decorator_stage(cls)

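# Example (sketch): decorator arguments wire a stage into Metamist status reporting
# ('cram' as the analysis type/key and the `Align` prerequisite are illustrative):
#
#     @stage(required_stages=[Align], analysis_type='cram', analysis_keys=['cram'])
#     class RegisterCram(SequencingGroupStage):
#         ...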

class SequencingGroupStage(Stage[SequencingGroup], ABC):
    """
    Sequencing Group level stage.
    """

    @abstractmethod
    def expected_outputs(self, sequencing_group: SequencingGroup) -> ExpectedResultT:
        """
        Override to declare expected output paths.
        """

    @abstractmethod
    def queue_jobs(
        self,
        sequencing_group: SequencingGroup,
        inputs: StageInput,
    ) -> StageOutput | None:
        """
        Override to add Hail Batch jobs.
        """

    def queue_for_multicohort(
        self,
        multicohort: MultiCohort,
    ) -> dict[str, StageOutput | None]:
        """
        Plug the stage into the workflow.
        """
        output_by_target: dict[str, StageOutput | None] = dict()
        if not (active_sgs := multicohort.get_sequencing_groups()):
            all_sgs = len(multicohort.get_sequencing_groups(only_active=False))
            LOGGER.warning(
                f'{len(active_sgs)}/{all_sgs} usable (active=True) SGs found in the multicohort. '
                'Check that `input_cohorts` or `input_datasets` are provided and not skipped',
            )
            return output_by_target

        # Evaluate each active sequencing group
        for sequencing_group in active_sgs:
            action = self._get_action(sequencing_group)
            output_by_target[sequencing_group.target_id] = self._queue_jobs_with_checks(
                sequencing_group,
                action,
            )
        return output_by_target


class DatasetStage(Stage, ABC):
    """
    Dataset-level stage.
    """

    @abstractmethod
    def expected_outputs(self, dataset: Dataset) -> ExpectedResultT:
        """
        Override to declare expected output paths.
        """

    @abstractmethod
    def queue_jobs(self, dataset: Dataset, inputs: StageInput) -> StageOutput | None:
        """
        Override to add Hail Batch jobs.
        """

    def queue_for_multicohort(
        self,
        multicohort: MultiCohort,
    ) -> dict[str, StageOutput | None]:
        """
        Plug the stage into the workflow.
        """
        output_by_target: dict[str, StageOutput | None] = dict()
        # Iterate directly over the datasets in this multicohort
        for dataset_i, dataset in enumerate(multicohort.get_datasets()):
            action = self._get_action(dataset)
            LOGGER.info(f'{self.name}: #{dataset_i + 1}/{dataset} [{action.name}]')
            output_by_target[dataset.target_id] = self._queue_jobs_with_checks(
                dataset,
                action,
            )
        return output_by_target


class CohortStage(Stage, ABC):
    """
    Cohort-level stage (one cohort of a workflow run).
    """

    @abstractmethod
    def expected_outputs(self, cohort: Cohort) -> ExpectedResultT:
        """
        Override to declare expected output paths.
        """

    @abstractmethod
    def queue_jobs(self, cohort: Cohort, inputs: StageInput) -> StageOutput | None:
        """
        Override to add Hail Batch jobs.
        """

    def queue_for_multicohort(
        self,
        multicohort: MultiCohort,
    ) -> dict[str, StageOutput | None]:
        """
        Plug the stage into the workflow.
        """
        output_by_target: dict[str, StageOutput | None] = dict()
        for cohort in multicohort.get_cohorts():
            action = self._get_action(cohort)
            LOGGER.info(f'{self.name}: {cohort} [{action.name}]')
            output_by_target[cohort.target_id] = self._queue_jobs_with_checks(
                cohort,
                action,
            )
        return output_by_target


class MultiCohortStage(Stage, ABC):
    """
    MultiCohort-level stage (all datasets and cohorts of a workflow run).
    """

    @abstractmethod
    def expected_outputs(self, multicohort: MultiCohort) -> ExpectedResultT:
        """
        Override to declare expected output paths.
        """

    @abstractmethod
    def queue_jobs(
        self,
        multicohort: MultiCohort,
        inputs: StageInput,
    ) -> StageOutput | None:
        """
        Override to add Hail Batch jobs.
        """

    def queue_for_multicohort(
        self,
        multicohort: MultiCohort,
    ) -> dict[str, StageOutput | None]:
        """
        Plug the stage into the workflow.
        """
        output_by_target: dict[str, StageOutput | None] = dict()
        action = self._get_action(multicohort)
        LOGGER.info(f'{self.name}: {multicohort} [{action.name}]')
        output_by_target[multicohort.target_id] = self._queue_jobs_with_checks(
            multicohort,
            action,
        )
        return output_by_target
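
# Example (sketch): stages at different levels compose through `required_stages`.
# A DatasetStage can consume per-sequencing-group outputs from an upstream
# SequencingGroupStage (`MyStage` is a hypothetical upstream stage):
#
#     @stage(required_stages=[MyStage])
#     class SummariseDataset(DatasetStage):
#         def expected_outputs(self, dataset: Dataset):
#             return self.prefix / f'{dataset.name}.summary.txt'
#
#         def queue_jobs(self, dataset: Dataset, inputs: StageInput) -> StageOutput | None:
#             # Per-SG output paths from the upstream stage, indexed by target ID:
#             paths = inputs.as_path_by_target(MyStage)
#             j = get_batch().new_job('SummariseDataset', self.get_job_attrs(dataset))
#             j.command(f'echo "summarising {len(paths)} inputs"')
#             return self.make_outputs(dataset, data=self.expected_outputs(dataset), jobs=j)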
LOGGER = <Logger cpg_flow.stage (INFO)>
StageDecorator = collections.abc.Callable[..., 'Stage']
class StageInputNotFoundError(builtins.Exception):
48class StageInputNotFoundError(Exception):
49    """
50    Thrown when a stage requests input from another stage
51    that doesn't exist.
52    """

Thrown when a stage requests input from another stage that doesn't exist.

class StageOutput:
 56class StageOutput:
 57    """
 58    Represents a result of a specific stage, which was run on a specific target.
 59    Can be a file path, or a Hail Batch Resource. Optionally wrapped in a dict.
 60    """
 61
 62    def __init__(
 63        self,
 64        target: Target,
 65        data: ExpectedResultT = None,
 66        jobs: Sequence[Job | None] | Job | None = None,
 67        meta: dict | None = None,
 68        reusable: bool = False,
 69        skipped: bool = False,
 70        error_msg: str | None = None,
 71        stage: Optional['Stage'] = None,
 72    ):
 73        # Converting str into Path objects.
 74        self.data = data
 75        self.stage = stage
 76        self.target = target
 77        _jobs = [jobs] if isinstance(jobs, Job) else (jobs or [])
 78        self.jobs: list[Job] = [j for j in _jobs if j is not None]
 79        self.meta: dict = meta or {}
 80        self.reusable = reusable
 81        self.skipped = skipped
 82        self.error_msg = error_msg
 83
 84    def __repr__(self) -> str:
 85        res = (
 86            f'StageOutput({self.data}'
 87            f' target={self.target}'
 88            f' stage={self.stage}'
 89            + (' [reusable]' if self.reusable else '')
 90            + (' [skipped]' if self.skipped else '')
 91            + (f' [error: {self.error_msg}]' if self.error_msg else '')
 92            + f' meta={self.meta}'
 93            + ')'
 94        )
 95        return res
 96
 97    def _get(self, key=None) -> str | Path:
 98        if self.data is None:
 99            raise ValueError(f'{self.stage}: output data is not available')
100
101        if key is not None:
102            if not isinstance(self.data, dict):
103                raise ValueError(
104                    f'{self.stage}: {self.data} is not a dictionary, can\'t get "{key}"',
105                )
106            res = cast(dict, self.data)[key]
107        else:
108            res = self.data
109        return res
110
111    def as_str(self, key=None) -> str:
112        """
113        Cast the result to a simple string. Throw an exception when can't cast.
114        `key` is used to extract the value when the result is a dictionary.
115        """
116        res = self._get(key)
117        if not isinstance(res, str):
118            raise ValueError(f'{res} is not a str.')
119        return cast(str, res)
120
121    def as_path(self, key=None) -> Path:
122        """
123        Cast the result to a path object. Throw an exception when can't cast.
124        `key` is used to extract the value when the result is a dictionary.
125        """
126        res = self._get(key)
127        if not isinstance(res, CloudPath | pathlib.Path):
128            raise ValueError(f'{res} is not a path object.')
129
130        return cast(Path, res)
131
132    def as_dict(self) -> dict[str, Path]:
133        """
134        Cast the result to a dictionary, or throw an error if the cast failed.
135        """
136        if not isinstance(self.data, dict):
137            raise ValueError(f'{self.data} is not a dictionary.')
138        return self.data

Represents a result of a specific stage, which was run on a specific target. Can be a file path, or a Hail Batch Resource. Optionally wrapped in a dict.

StageOutput( target: cpg_flow.targets.target.Target, data: Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType] = None, jobs: Sequence[hailtop.batch.job.Job | None] | hailtop.batch.job.Job | None = None, meta: dict | None = None, reusable: bool = False, skipped: bool = False, error_msg: str | None = None, stage: Optional[Stage] = None)
62    def __init__(
63        self,
64        target: Target,
65        data: ExpectedResultT = None,
66        jobs: Sequence[Job | None] | Job | None = None,
67        meta: dict | None = None,
68        reusable: bool = False,
69        skipped: bool = False,
70        error_msg: str | None = None,
71        stage: Optional['Stage'] = None,
72    ):
73        # Converting str into Path objects.
74        self.data = data
75        self.stage = stage
76        self.target = target
77        _jobs = [jobs] if isinstance(jobs, Job) else (jobs or [])
78        self.jobs: list[Job] = [j for j in _jobs if j is not None]
79        self.meta: dict = meta or {}
80        self.reusable = reusable
81        self.skipped = skipped
82        self.error_msg = error_msg
data
stage
target
jobs: list[hailtop.batch.job.Job]
meta: dict
reusable
skipped
error_msg
def as_str(self, key=None) -> str:
111    def as_str(self, key=None) -> str:
112        """
113        Cast the result to a simple string. Throw an exception when can't cast.
114        `key` is used to extract the value when the result is a dictionary.
115        """
116        res = self._get(key)
117        if not isinstance(res, str):
118            raise ValueError(f'{res} is not a str.')
119        return cast(str, res)

Cast the result to a simple string. Throw an exception when can't cast. key is used to extract the value when the result is a dictionary.

def as_path(self, key=None) -> cloudpathlib.cloudpath.CloudPath | pathlib.Path:
121    def as_path(self, key=None) -> Path:
122        """
123        Cast the result to a path object. Throw an exception when can't cast.
124        `key` is used to extract the value when the result is a dictionary.
125        """
126        res = self._get(key)
127        if not isinstance(res, CloudPath | pathlib.Path):
128            raise ValueError(f'{res} is not a path object.')
129
130        return cast(Path, res)

Cast the result to a path object. Throw an exception when can't cast. key is used to extract the value when the result is a dictionary.

def as_dict(self) -> dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path]:
132    def as_dict(self) -> dict[str, Path]:
133        """
134        Cast the result to a dictionary, or throw an error if the cast failed.
135        """
136        if not isinstance(self.data, dict):
137            raise ValueError(f'{self.data} is not a dictionary.')
138        return self.data

Cast the result to a dictionary, or throw an error if the cast failed.

class StageInput:
142class StageInput:
143    """
144    Represents an input for a stage run. It wraps the outputs of all required upstream
145    stages for corresponding targets (e.g. all GVCFs from a GenotypeSample stage
146    for a JointCalling stage, along with Hail Batch jobs).
147
148    An object of this class is passed to the public `queue_jobs` method of a Stage,
149    and can be used to query dependency files and jobs.
150    """
151
152    def __init__(self, stage: 'Stage'):
153        self.stage = stage
154        self._outputs_by_target_by_stage: dict[str, dict[str, StageOutput | None]] = {}
155
156    def add_other_stage_output(self, output: StageOutput):
157        """
158        Add output from another stage run.
159        """
160        assert output.stage is not None, output
161        if not output.target.active:
162            return
163        if not output.target.get_sequencing_groups():
164            return
165        if not output.data and not output.jobs:
166            return
167        stage_name = output.stage.name
168        target_id = output.target.target_id
169        if stage_name not in self._outputs_by_target_by_stage:
170            self._outputs_by_target_by_stage[stage_name] = dict()
171        self._outputs_by_target_by_stage[stage_name][target_id] = output
172
173    def _each(
174        self,
175        fun: Callable,
176        stage: StageDecorator,
177    ):
178        if stage.__name__ not in [s.name for s in self.stage.required_stages]:
179            raise WorkflowError(
180                f'{self.stage.name}: getting inputs from stage {stage.__name__}, '
181                f'but {stage.__name__} is not listed in required_stages. '
182                f'Consider adding it into the decorator: '
183                f'@stage(required_stages=[{stage.__name__}])',
184            )
185
186        if stage.__name__ not in self._outputs_by_target_by_stage:
187            raise WorkflowError(
188                f'No inputs from {stage.__name__} for {self.stage.name} found '
189                + 'after skipping targets with missing inputs. '
190                + (
191                    'Check the logs if all sequencing groups were missing inputs from previous '
192                    'stages, and consider changing `workflow/first_stage`'
193                    if get_config()['workflow'].get('skip_sgs_with_missing_input')
194                    else ''
195                ),
196            )
197
198        return {
199            trg: fun(result)
200            for trg, result in self._outputs_by_target_by_stage.get(
201                stage.__name__,
202                {},
203            ).items()
204        }
205
206    def as_path_by_target(
207        self,
208        stage: StageDecorator,
209        key: str | None = None,
210    ) -> dict[str, Path]:
211        """
212        Get a single file path result, indexed by target for a specific stage
213        """
214        return self._each(fun=(lambda r: r.as_path(key=key)), stage=stage)
215
216    def as_dict_by_target(self, stage: StageDecorator) -> dict[str, dict[str, Path]]:
217        """
218        Get as a dict of files/resources for a specific stage, indexed by target
219        """
220        return self._each(fun=(lambda r: r.as_dict()), stage=stage)
221
222    def as_path_dict_by_target(
223        self,
224        stage: StageDecorator,
225    ) -> dict[str, dict[str, Path]]:
226        """
227        Get a dict of paths for a specific stage, and indexed by target
228        """
229        return self._each(fun=(lambda r: r.as_path_dict()), stage=stage)
230
231    def _get(
232        self,
233        target: Target,
234        stage: StageDecorator,
235    ):
236        if not self._outputs_by_target_by_stage.get(stage.__name__):
237            raise StageInputNotFoundError(
238                f'Not found output from stage {stage.__name__}, required for stage '
239                f'{self.stage.name}. Is {stage.__name__} in the `required_stages`'
240                f'decorator? Available: {self._outputs_by_target_by_stage}',
241            )
242        if not self._outputs_by_target_by_stage[stage.__name__].get(target.target_id):
243            raise StageInputNotFoundError(
244                f'Not found output for {target} from stage {stage.__name__}, required for stage {self.stage.name}',
245            )
246        return self._outputs_by_target_by_stage[stage.__name__][target.target_id]
247
248    def as_path(
249        self,
250        target: Target,
251        stage: StageDecorator,
252        key: str | None = None,
253    ) -> Path:
254        """
255        Represent as a path to a file, otherwise fail.
256        `stage` can be callable, or a subclass of Stage
257        """
258        res = self._get(target=target, stage=stage)
259        return res.as_path(key)
260
261    def as_str(
262        self,
263        target: Target,
264        stage: StageDecorator,
265        key: str | None = None,
266    ) -> str:
267        """
268        Represent as a simple string, otherwise fail.
269        `stage` can be callable, or a subclass of Stage
270        """
271        res = self._get(target=target, stage=stage)
272        return res.as_str(key)
273
274    def as_dict(self, target: Target, stage: StageDecorator) -> dict[str, Path]:
275        """
276        Get a dict of paths for a specific target and stage
277        """
278        res = self._get(target=target, stage=stage)
279        return res.as_dict()
280
281    def get_jobs(self, target: Target) -> list[Job]:
282        """
283        Get list of jobs that the next stage would depend on.
284        """
285        all_jobs: list[Job] = []
286        target_sequencing_groups = target.get_sequencing_group_ids()
287        for stage_, outputs_by_target in self._outputs_by_target_by_stage.items():
288            for target_, output in outputs_by_target.items():
289                if output:
290                    output_sequencing_groups = output.target.get_sequencing_group_ids()
291                    sequencing_groups_intersect = set(target_sequencing_groups) & set(
292                        output_sequencing_groups,
293                    )
294                    if sequencing_groups_intersect:
295                        for j in output.jobs:
296                            assert j, f'Stage: {stage_}, target: {target_}, output: {output}'
297                        all_jobs.extend(output.jobs)
298        return all_jobs

Represents an input for a stage run. It wraps the outputs of all required upstream stages for corresponding targets (e.g. all GVCFs from a GenotypeSample stage for a JointCalling stage, along with Hail Batch jobs).

An object of this class is passed to the public queue_jobs method of a Stage, and can be used to query dependency files and jobs.

StageInput(stage: Stage)
152    def __init__(self, stage: 'Stage'):
153        self.stage = stage
154        self._outputs_by_target_by_stage: dict[str, dict[str, StageOutput | None]] = {}
stage
def add_other_stage_output(self, output: StageOutput):
156    def add_other_stage_output(self, output: StageOutput):
157        """
158        Add output from another stage run.
159        """
160        assert output.stage is not None, output
161        if not output.target.active:
162            return
163        if not output.target.get_sequencing_groups():
164            return
165        if not output.data and not output.jobs:
166            return
167        stage_name = output.stage.name
168        target_id = output.target.target_id
169        if stage_name not in self._outputs_by_target_by_stage:
170            self._outputs_by_target_by_stage[stage_name] = dict()
171        self._outputs_by_target_by_stage[stage_name][target_id] = output

Add output from another stage run.

def as_path_by_target( self, stage: Callable, key: str | None = None) -> dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path]:
206    def as_path_by_target(
207        self,
208        stage: StageDecorator,
209        key: str | None = None,
210    ) -> dict[str, Path]:
211        """
212        Get a single file path result, indexed by target for a specific stage
213        """
214        return self._each(fun=(lambda r: r.as_path(key=key)), stage=stage)

Get a single file path result, indexed by target for a specific stage

def as_dict_by_target( self, stage: Callable) -> dict[str, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path]]:
216    def as_dict_by_target(self, stage: StageDecorator) -> dict[str, dict[str, Path]]:
217        """
218        Get as a dict of files/resources for a specific stage, indexed by target
219        """
220        return self._each(fun=(lambda r: r.as_dict()), stage=stage)

Get as a dict of files/resources for a specific stage, indexed by target

def as_path_dict_by_target( self, stage: Callable) -> dict[str, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path]]:
222    def as_path_dict_by_target(
223        self,
224        stage: StageDecorator,
225    ) -> dict[str, dict[str, Path]]:
226        """
227        Get a dict of paths for a specific stage, and indexed by target
228        """
229        return self._each(fun=(lambda r: r.as_path_dict()), stage=stage)

Get a dict of paths for a specific stage, and indexed by target

def as_path( self, target: cpg_flow.targets.target.Target, stage: Callable, key: str | None = None) -> cloudpathlib.cloudpath.CloudPath | pathlib.Path:
248    def as_path(
249        self,
250        target: Target,
251        stage: StageDecorator,
252        key: str | None = None,
253    ) -> Path:
254        """
255        Represent as a path to a file, otherwise fail.
256        `stage` can be callable, or a subclass of Stage
257        """
258        res = self._get(target=target, stage=stage)
259        return res.as_path(key)

Represent as a path to a file, otherwise fail. stage can be callable, or a subclass of Stage

def as_str( self, target: cpg_flow.targets.target.Target, stage: Callable, key: str | None = None) -> str:
261    def as_str(
262        self,
263        target: Target,
264        stage: StageDecorator,
265        key: str | None = None,
266    ) -> str:
267        """
268        Represent as a simple string, otherwise fail.
269        `stage` can be callable, or a subclass of Stage
270        """
271        res = self._get(target=target, stage=stage)
272        return res.as_str(key)

Represent as a simple string, otherwise fail. stage can be callable, or a subclass of Stage

def as_dict( self, target: cpg_flow.targets.target.Target, stage: Callable) -> dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path]:
274    def as_dict(self, target: Target, stage: StageDecorator) -> dict[str, Path]:
275        """
276        Get a dict of paths for a specific target and stage
277        """
278        res = self._get(target=target, stage=stage)
279        return res.as_dict()

Get a dict of paths for a specific target and stage

def get_jobs( self, target: cpg_flow.targets.target.Target) -> list[hailtop.batch.job.Job]:
281    def get_jobs(self, target: Target) -> list[Job]:
282        """
283        Get list of jobs that the next stage would depend on.
284        """
285        all_jobs: list[Job] = []
286        target_sequencing_groups = target.get_sequencing_group_ids()
287        for stage_, outputs_by_target in self._outputs_by_target_by_stage.items():
288            for target_, output in outputs_by_target.items():
289                if output:
290                    output_sequencing_groups = output.target.get_sequencing_group_ids()
291                    sequencing_groups_intersect = set(target_sequencing_groups) & set(
292                        output_sequencing_groups,
293                    )
294                    if sequencing_groups_intersect:
295                        for j in output.jobs:
296                            assert j, f'Stage: {stage_}, target: {target_}, output: {output}'
297                        all_jobs.extend(output.jobs)
298        return all_jobs

Get list of jobs that the next stage would depend on.
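
The workflow already calls this in _queue_jobs_with_checks() to wire depends_on() between consecutive stages, so explicit use is rarely needed. A sketch of doing the same wiring by hand inside queue_jobs():

    # make a new job wait for every upstream job that touched this target
    j = get_batch().new_job('my-job', attributes=self.get_job_attrs(target))
    for upstream_job in inputs.get_jobs(target):
        j.depends_on(upstream_job)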

class Stage(typing.Generic[~TargetT], abc.ABC):
301class Stage(Generic[TargetT], ABC):
302    """
303    Abstract class for a workflow stage. Parametrised by specific Target subclass,
304    i.e. SequencingGroupStage(Stage[SequencingGroup]) should only be able to work on SequencingGroup(Target).
305    """
306
307    def __init__(
308        self,
309        name: str,
310        required_stages: list[StageDecorator] | StageDecorator | None = None,
311        analysis_type: str | None = None,
312        analysis_keys: list[str] | None = None,
313        update_analysis_meta: Callable[[str], dict] | None = None,
314        tolerate_missing_output: bool = False,
315        skipped: bool = False,
316        assume_outputs_exist: bool = False,
317        forced: bool = False,
318    ):
319        self._name = name
320        self.required_stages_classes: list[StageDecorator] = []
321        if required_stages:
322            if isinstance(required_stages, list):
323                self.required_stages_classes.extend(required_stages)
324            else:
325                self.required_stages_classes.append(required_stages)
326
327        # Dependencies. Populated in workflow.run(), after we know all stages.
328        self.required_stages: list[Stage] = []
329
330        self.status_reporter = get_workflow().status_reporter
331        # If `analysis_type` is defined, it will be used to create/update Analysis
332        # entries in Metamist.
333        self.analysis_type = analysis_type
334        # If `analysis_keys` is defined, it will be used to extract the values for
335        # `Analysis.output` if `Stage.expected_outputs()` returns a dict.
336        self.analysis_keys = analysis_keys
337        # if `update_analysis_meta` is defined, it is called on the `Analysis.output`
338        # field, and result is merged into the `Analysis.meta` dictionary.
339        self.update_analysis_meta = update_analysis_meta
340
341        self.tolerate_missing_output = tolerate_missing_output
342
343        # Populated with the return value of `add_to_the_workflow()`
344        self.output_by_target: dict[str, StageOutput | None] = dict()
345
346        self.skipped = skipped
347        self.forced = forced or self.name in get_config()['workflow'].get(
348            'force_stages',
349            [],
350        )
351        self.assume_outputs_exist = assume_outputs_exist
352
353    @property
354    def tmp_prefix(self) -> Path:
355        return get_workflow().tmp_prefix / self.name
356
357    @property
358    def web_prefix(self) -> Path:
359        return get_workflow().web_prefix / self.name
360
361    @property
362    def prefix(self) -> Path:
363        return get_workflow().prefix / self.name
364
365    @property
366    def analysis_prefix(self) -> Path:
367        return get_workflow().analysis_prefix / self.name
368
369    def get_stage_cohort_prefix(
370        self,
371        cohort: Cohort,
372        category: str | None = None,
373    ) -> Path:
374        """
375        Takes a cohort as an argument, calls through to the Workflow cohort_prefix method.
376        Result is of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID / STAGE_NAME,
377        e.g. "gs://cpg-project-main/seqr_loader/COH123/MyStage"
378
379        Args:
380            cohort (Cohort): we pull the analysis dataset and name from this Cohort
381            category (str | None): main, tmp, test, analysis, web
382
383        Returns:
384            Path
385        """
386        return get_workflow().cohort_prefix(cohort, category=category) / self.name
387
388    def __str__(self):
389        res = f'{self._name}'
390        if self.skipped:
391            res += ' [skipped]'
392        if self.forced:
393            res += ' [forced]'
394        if self.assume_outputs_exist:
395            res += ' [assume_outputs_exist]'
396        if self.required_stages:
397            res += f' <- [{", ".join([s.name for s in self.required_stages])}]'
398        return res
399
400    @property
401    def name(self) -> str:
402        """
403        Stage name (unique and descriptive)
404        """
405        return self._name
406
407    @abstractmethod
408    def queue_jobs(self, target: TargetT, inputs: StageInput) -> StageOutput | None:
409        """
410        Adds Hail Batch jobs that process `target`.
411        Assumes that all the housekeeping work is done: checking missing inputs
412        from required stages, checking for possible reuse of existing outputs.
413        """
414
415    @abstractmethod
416    def expected_outputs(self, target: TargetT) -> ExpectedResultT:
417        """
418        Get path(s) to files that the stage is expected to generate for a `target`.
419        Used within `queue_jobs()` to pass output paths to job commands,
420        as well as by the workflow to check if the stage's expected outputs already
421        exist and can be reused.
422
423        Can be a str, a Path object, or a dictionary of str/Path objects.
424        """
425
426    # TODO: remove this method
427    def deprecated_queue_for_cohort(
428        self,
429        cohort: Cohort,
430    ) -> dict[str, StageOutput | None]:
431        """
432        Queues jobs for each corresponding target, defined by Stage subclass.
433        Returns a dictionary of `StageOutput` objects indexed by target unique_id.
434        Unused; ready for deletion.
435        """
436        return {}
437
438    @abstractmethod
439    def queue_for_multicohort(
440        self,
441        multicohort: MultiCohort,
442    ) -> dict[str, StageOutput | None]:
443        """
444        Queues jobs for each corresponding target, defined by Stage subclass.
445
446        Returns a dictionary of `StageOutput` objects indexed by target unique_id.
447        """
448
449    def _make_inputs(self) -> StageInput:
450        """
451        Collects outputs from all dependencies and create input for this stage
452        """
453        inputs = StageInput(self)
454        for prev_stage in self.required_stages:
455            for _, stage_output in prev_stage.output_by_target.items():
456                if stage_output:
457                    inputs.add_other_stage_output(stage_output)
458        return inputs
459
460    def make_outputs(
461        self,
462        target: Target,
463        data: ExpectedResultT = None,  # TODO: ExpectedResultT is probably too broad, our code only really supports dicts
464        jobs: Sequence[Job | None] | Job | None = None,
465        meta: dict | None = None,
466        reusable: bool = False,
467        skipped: bool = False,
468        error_msg: str | None = None,
469    ) -> StageOutput:
470        """
471        Create StageOutput for this stage.
472        """
473        return StageOutput(
474            target=target,
475            data=data,
476            jobs=jobs,
477            meta=meta,
478            reusable=reusable,
479            skipped=skipped,
480            error_msg=error_msg,
481            stage=self,
482        )
483
484    def _queue_jobs_with_checks(
485        self,
486        target: TargetT,
487        action: Action | None = None,
488    ) -> StageOutput | None:
489        """
490        Checks what to do with target, and either queue jobs, or skip/reuse results.
491        """
492        if not action:
493            action = self._get_action(target)
494
495        inputs = self._make_inputs()
496        expected_out = self.expected_outputs(target)
497
498        if action == Action.QUEUE:
499            outputs = self.queue_jobs(target, inputs)
500        elif action == Action.REUSE:
501            outputs = self.make_outputs(
502                target=target,
503                data=expected_out,
504                reusable=True,
505            )
506        else:  # Action.SKIP
507            outputs = None
508
509        if not outputs:
510            return None
511
512        outputs.stage = self
513        outputs.meta |= self.get_job_attrs(target)
514
515        for output_job in outputs.jobs:
516            if output_job:
517                for input_job in inputs.get_jobs(target):
518                    assert input_job, f'Input dependency job for stage: {self}, target: {target}'
519                    output_job.depends_on(input_job)
520
521        if outputs.error_msg:
522            return outputs
523
524        # Adding status reporter jobs
525        if self.analysis_type and self.status_reporter and action == Action.QUEUE and outputs.data:
526            analysis_outputs: list[str | Path] = []
527            if isinstance(outputs.data, dict):
528                if not self.analysis_keys:
529                    raise WorkflowError(
530                        f'Cannot create Analysis: `analysis_keys` '
531                        f'must be set with the @stage decorator to select value from '
532                        f'the expected_outputs dict: {outputs.data}',
533                    )
534                if not all(key in outputs.data for key in self.analysis_keys):
535                    raise WorkflowError(
536                        f'Cannot create Analysis for stage {self.name}: `analysis_keys` '
537                        f'"{self.analysis_keys}" is not a subset of the expected_outputs '
538                        f'keys {outputs.data.keys()}',
539                    )
540
541                for analysis_key in self.analysis_keys:
542                    analysis_outputs.append(outputs.data[analysis_key])
543
544            else:
545                analysis_outputs.append(outputs.data)
546
547            project_name = None
548            if isinstance(target, SequencingGroup):
549                project_name = target.dataset.name
550            elif isinstance(target, Dataset):
551                project_name = target.name
552            elif isinstance(target, Cohort | MultiCohort):
553                project_name = target.analysis_dataset.name
554
555            assert isinstance(project_name, str)
556
557            # bump name to include `-test`
558            if get_config()['workflow']['access_level'] == 'test' and 'test' not in project_name:
559                project_name = f'{project_name}-test'
560
561            for analysis_output in analysis_outputs:
562                if not outputs.jobs:
563                    continue
564
565                assert isinstance(
566                    analysis_output,
567                    str | Path,
568                ), f'{analysis_output} should be a str or Path object'
569                if outputs.meta is None:
570                    outputs.meta = {}
571
572                self.status_reporter.create_analysis(
573                    b=get_batch(),
574                    output=str(analysis_output),
575                    analysis_type=self.analysis_type,
576                    target=target,
577                    jobs=outputs.jobs,
578                    job_attr=self.get_job_attrs(target) | {'stage': self.name, 'tool': 'metamist'},
579                    meta=outputs.meta,
580                    update_analysis_meta=self.update_analysis_meta,
581                    tolerate_missing_output=self.tolerate_missing_output,
582                    project_name=project_name,
583                )
584
585        return outputs
586
587    def _get_action(self, target: TargetT) -> Action:
588        """
589        Based on stage parameters and expected outputs existence, determines what
590        to do with the target: queue, skip or reuse, etc...
591        """
592        if target.forced and not self.skipped:
593            LOGGER.info(f'{self.name}: {target} [QUEUE] (target is forced)')
594            return Action.QUEUE
595
596        if (d := get_config()['workflow'].get('skip_stages_for_sgs')) and self.name in d:
597            skip_targets = d[self.name]
598            if target.target_id in skip_targets:
599                LOGGER.info(
600                    f'{self.name}: {target} [SKIP] (is in workflow/skip_stages_for_sgs)',
601                )
602                return Action.SKIP
603
604        expected_out = self.expected_outputs(target)
605        reusable, first_missing_path = self._is_reusable(expected_out)
606
607        if self.skipped:
608            if reusable and not first_missing_path:
609                LOGGER.debug(
610                    f'{self.name}: {target} [REUSE] (stage skipped, and outputs exist)',
611                )
612                return Action.REUSE
613            if get_config()['workflow'].get('skip_sgs_with_missing_input'):
614                LOGGER.warning(
615                    f'{self.name}: {target} [SKIP] (stage is required, '
616                    f'but is marked as "skipped", '
617                    f'workflow/skip_sgs_with_missing_input=true '
618                    f'and some expected outputs for the target do not exist: '
619                    f'{first_missing_path})',
620                )
621                # `workflow/skip_sgs_with_missing_input` means that we can ignore
622                # sgs/datasets that have missing results from skipped stages.
623                # This is our case, so indicating that this sg/dataset should
624                # be ignored:
625                target.active = False
626                return Action.SKIP
627            if self.name in get_config()['workflow'].get(
628                'allow_missing_outputs_for_stages',
629                [],
630            ):
631                LOGGER.info(
632                    f'{self.name}: {target} [REUSE] (stage is skipped, some outputs are '
633                    f'missing, but stage is listed in '
634                    f'workflow/allow_missing_outputs_for_stages)',
635                )
636                return Action.REUSE
637            raise WorkflowError(
638                f'{self.name}: stage is required, but is skipped, and '
639                f'the following expected outputs for target {target} do not exist: '
640                f'{first_missing_path}',
641            )
642
643        if reusable and not first_missing_path:
644            if target.forced:
645                LOGGER.info(
646                    f'{self.name}: {target} [QUEUE] (can reuse, but forcing the target to rerun this stage)',
647                )
648                return Action.QUEUE
649            if self.forced:
650                LOGGER.info(
651                    f'{self.name}: {target} [QUEUE] (can reuse, but forcing the stage to rerun)',
652                )
653                return Action.QUEUE
654            LOGGER.info(
655                f'{self.name}: {target} [REUSE] (expected outputs exist: {expected_out})',
656            )
657            return Action.REUSE
658
659        LOGGER.info(f'{self.name}: {target} [QUEUE]')
660
661        return Action.QUEUE
662
663    def _is_reusable(self, expected_out: ExpectedResultT) -> tuple[bool, Path | None]:
664        """
665        Checks if the outputs of prior stages already exist, and can be reused
666        Args:
667            expected_out (ExpectedResultT): expected outputs of a stage
668
669        Returns:
670            tuple[bool, Path | None]:
671                bool: True if the outputs can be reused, False otherwise
672                Path | None: first missing path, if any
673        """
674        if self.assume_outputs_exist:
675            LOGGER.debug(f'Assuming outputs exist. Expected output is {expected_out}')
676            return True, None
677
678        if not expected_out:
679            # Marking as reusable. If the stage does not naturally produce any outputs,
680            # it would still need to create some flag file.
681            LOGGER.debug('No expected outputs, assuming outputs exist')
682            return True, None
683
684        if get_config()['workflow'].get('check_expected_outputs'):
685            paths = path_walk(expected_out)
686            LOGGER.info(
687                f'Checking if {paths} from expected output {expected_out} exist',
688            )
689            if not paths:
690                LOGGER.info(f'{expected_out} is not reusable. No paths found.')
691                return False, None
692
693            if first_missing_path := next((p for p in paths if not exists(p)), None):
694                LOGGER.info(
695                    f'{expected_out} is not reusable, {first_missing_path} is missing',
696                )
697                return False, first_missing_path
698
699            return True, None
700        if self.skipped:
701            # Do not check the files' existence, trust they exist.
702            # note that for skipped stages, we automatically assume outputs exist
703            return True, None
704        # Do not check the files' existence, assume they don't exist:
705        return False, None
706
707    def get_job_attrs(self, target: TargetT | None = None) -> dict[str, str]:
708        """
709        Create Hail Batch Job attributes dictionary
710        """
711        job_attrs = dict(stage=self.name)
712        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
713            job_attrs['sequencing_type'] = sequencing_type
714        if target:
715            job_attrs |= target.get_job_attrs()
716        return job_attrs

Abstract class for a workflow stage. Parametrised by specific Target subclass, i.e. SequencingGroupStage(Stage[SequencingGroup]) should only be able to work on SequencingGroup(Target).

required_stages_classes: list[Callable[..., Stage]]
required_stages: list[Stage]
status_reporter
analysis_type
analysis_keys
update_analysis_meta
tolerate_missing_output
output_by_target: dict[str, StageOutput | None]
skipped
forced
assume_outputs_exist
tmp_prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
353    @property
354    def tmp_prefix(self) -> Path:
355        return get_workflow().tmp_prefix / self.name
web_prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
357    @property
358    def web_prefix(self) -> Path:
359        return get_workflow().web_prefix / self.name
prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
361    @property
362    def prefix(self) -> Path:
363        return get_workflow().prefix / self.name
analysis_prefix: cloudpathlib.cloudpath.CloudPath | pathlib.Path
365    @property
366    def analysis_prefix(self) -> Path:
367        return get_workflow().analysis_prefix / self.name
def get_stage_cohort_prefix( self, cohort: cpg_flow.targets.cohort.Cohort, category: str | None = None) -> cloudpathlib.cloudpath.CloudPath | pathlib.Path:
369    def get_stage_cohort_prefix(
370        self,
371        cohort: Cohort,
372        category: str | None = None,
373    ) -> Path:
374        """
375        Takes a cohort as an argument, calls through to the Workflow cohort_prefix method.
376        Result is of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID / STAGE_NAME,
377        e.g. "gs://cpg-project-main/seqr_loader/COH123/MyStage"
378
379        Args:
380            cohort (Cohort): we pull the analysis dataset and name from this Cohort
381            category (str | None): main, tmp, test, analysis, web
382
383        Returns:
384            Path
385        """
386        return get_workflow().cohort_prefix(cohort, category=category) / self.name

Takes a cohort as an argument and calls through to the Workflow cohort_prefix method. The result is of the form PROJECT_BUCKET / WORKFLOW_NAME / COHORT_ID / STAGE_NAME, e.g. "gs://cpg-project-main/seqr_loader/COH123/MyStage".

Args:
    cohort (Cohort): we pull the analysis dataset and name from this Cohort
    category (str | None): main, tmp, test, analysis, web

Returns:
    Path
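
A sketch of calling this from a cohort-level stage's expected_outputs(); the 'combined.vcf.gz' file name is illustrative:

    class MyCohortStage(CohortStage):
        def expected_outputs(self, cohort: Cohort) -> ExpectedResultT:
            # e.g. gs://cpg-project-main/seqr_loader/COH123/MyCohortStage/combined.vcf.gz
            return {'vcf': self.get_stage_cohort_prefix(cohort) / 'combined.vcf.gz'}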

name: str
400    @property
401    def name(self) -> str:
402        """
403        Stage name (unique and descriptive)
404        """
405        return self._name

Stage name (unique and descriptive)

@abstractmethod
def queue_jobs( self, target: ~TargetT, inputs: StageInput) -> StageOutput | None:
407    @abstractmethod
408    def queue_jobs(self, target: TargetT, inputs: StageInput) -> StageOutput | None:
409        """
410        Adds Hail Batch jobs that process `target`.
411        Assumes that all the housekeeping work is done: checking missing inputs
412        from required stages, checking for possible reuse of existing outputs.
413        """

Adds Hail Batch jobs that process target. Assumes that all the housekeeping work is done: checking missing inputs from required stages, checking for possible reuse of existing outputs.
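
A minimal sketch of an override; the job name, command, and 'flag' output key are illustrative stand-ins for real work:

    def queue_jobs(self, target: TargetT, inputs: StageInput) -> StageOutput | None:
        outputs = self.expected_outputs(target)
        j = get_batch().new_job('do-work', attributes=self.get_job_attrs(target))
        j.command(f'echo done > {j.out}')  # stand-in for the real command
        # copy the job's result to the declared output location
        get_batch().write_output(j.out, str(outputs['flag']))
        return self.make_outputs(target, data=outputs, jobs=j)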

@abstractmethod
def expected_outputs( self, target: ~TargetT) -> Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType]:
415    @abstractmethod
416    def expected_outputs(self, target: TargetT) -> ExpectedResultT:
417        """
418        Get path(s) to files that the stage is expected to generate for a `target`.
419        Used within `queue_jobs()` to pass output paths to job commands,
420        as well as by the workflow to check if the stage's expected outputs already
421        exist and can be reused.
422
423        Can be a str, a Path object, or a dictionary of str/Path objects.
424        """

Get path(s) to files that the stage is expected to generate for a target. Used within queue_jobs() to pass output paths to job commands, as well as by the workflow to check if the stage's expected outputs already exist and can be reused.

Can be a str, a Path object, or a dictionary of str/Path objects.
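
A sketch of a dict-shaped declaration for a sequencing-group stage; the output keys are illustrative, and target_id is the target's unique identifier used elsewhere in this module:

    def expected_outputs(self, sequencing_group: SequencingGroup) -> ExpectedResultT:
        return {
            'cram': self.prefix / f'{sequencing_group.target_id}.cram',
            'crai': self.prefix / f'{sequencing_group.target_id}.cram.crai',
        }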

def deprecated_queue_for_cohort( self, cohort: cpg_flow.targets.cohort.Cohort) -> dict[str, StageOutput | None]:
427    def deprecated_queue_for_cohort(
428        self,
429        cohort: Cohort,
430    ) -> dict[str, StageOutput | None]:
431        """
432        Queues jobs for each corresponding target, defined by Stage subclass.
433        Returns a dictionary of `StageOutput` objects indexed by target unique_id.
434        Unused; ready for deletion.
435        """
436        return {}

Queues jobs for each corresponding target, defined by Stage subclass. Returns a dictionary of StageOutput objects indexed by target unique_id. Unused; ready for deletion.

@abstractmethod
def queue_for_multicohort( self, multicohort: cpg_flow.targets.multicohort.MultiCohort) -> dict[str, StageOutput | None]:
438    @abstractmethod
439    def queue_for_multicohort(
440        self,
441        multicohort: MultiCohort,
442    ) -> dict[str, StageOutput | None]:
443        """
444        Queues jobs for each corresponding target, defined by Stage subclass.
445
446        Returns a dictionary of `StageOutput` objects indexed by target unique_id.
447        """

Queues jobs for each corresponding target, defined by Stage subclass.

Returns a dictionary of StageOutput objects indexed by target unique_id.

def make_outputs( self, target: cpg_flow.targets.target.Target, data: Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType] = None, jobs: Sequence[hailtop.batch.job.Job | None] | hailtop.batch.job.Job | None = None, meta: dict | None = None, reusable: bool = False, skipped: bool = False, error_msg: str | None = None) -> StageOutput:
460    def make_outputs(
461        self,
462        target: Target,
463        data: ExpectedResultT = None,  # TODO: ExpectedResultT is probably too broad, our code only really supports dicts
464        jobs: Sequence[Job | None] | Job | None = None,
465        meta: dict | None = None,
466        reusable: bool = False,
467        skipped: bool = False,
468        error_msg: str | None = None,
469    ) -> StageOutput:
470        """
471        Create StageOutput for this stage.
472        """
473        return StageOutput(
474            target=target,
475            data=data,
476            jobs=jobs,
477            meta=meta,
478            reusable=reusable,
479            skipped=skipped,
480            error_msg=error_msg,
481            stage=self,
482        )

Create StageOutput for this stage.
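
Besides wrapping data and jobs (as in the queue_jobs() sketch above), this can record a per-target failure without raising, letting _queue_jobs_with_checks() return it to the workflow; the message text is illustrative:

    # inside queue_jobs(), when a target cannot be processed
    return self.make_outputs(target, error_msg='no FASTQ inputs found for this sequencing group')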

def get_job_attrs(self, target: Optional[~TargetT] = None) -> dict[str, str]:
707    def get_job_attrs(self, target: TargetT | None = None) -> dict[str, str]:
708        """
709        Create Hail Batch Job attributes dictionary
710        """
711        job_attrs = dict(stage=self.name)
712        if sequencing_type := get_config()['workflow'].get('sequencing_type'):
713            job_attrs['sequencing_type'] = sequencing_type
714        if target:
715            job_attrs |= target.get_job_attrs()
716        return job_attrs

Create Hail Batch Job attributes dictionary
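
A sketch of attaching these attributes when creating a job, so the stage name, sequencing type, and target metadata appear on the job in Batch; the job name and extra 'tool' key are illustrative:

    j = get_batch().new_job(
        'align',
        attributes=self.get_job_attrs(sequencing_group) | {'tool': 'bwa-mem2'},
    )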

def stage( cls: type[Stage] | None = None, *, analysis_type: str | None = None, analysis_keys: list[str | cloudpathlib.cloudpath.CloudPath | pathlib.Path] | None = None, update_analysis_meta: Callable[[str], dict] | None = None, tolerate_missing_output: bool = False, required_stages: list[Callable[..., Stage]] | Callable[..., Stage] | None = None, skipped: bool = False, assume_outputs_exist: bool = False, forced: bool = False) -> Callable[..., Stage] | Callable[..., Callable[..., Stage]]:
719def stage(
720    cls: type['Stage'] | None = None,
721    *,
722    analysis_type: str | None = None,
723    analysis_keys: list[str | Path] | None = None,
724    update_analysis_meta: Callable[[str], dict] | None = None,
725    tolerate_missing_output: bool = False,
726    required_stages: list[StageDecorator] | StageDecorator | None = None,
727    skipped: bool = False,
728    assume_outputs_exist: bool = False,
729    forced: bool = False,
730) -> StageDecorator | Callable[..., StageDecorator]:
731    """
732    Implements a standard class decorator pattern with optional arguments.
733    The goal is to allow declaring workflow stages without having to implement
734    a constructor method. E.g.
735
736    @stage(required_stages=[Align])
737    class GenotypeSample(SequencingGroupStage):
738        def expected_outputs(self, sequencing_group: SequencingGroup):
739            ...
740        def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> StageOutput:
741            ...
742
743    @analysis_type: if defined, will be used to create/update `Analysis` entries
744        using the status reporter.
745    @analysis_keys: if defined, will be used to extract the value for `Analysis.output`
746        if the Stage.expected_outputs() returns a dict.
747    @update_analysis_meta: if defined, this function is called on the `Analysis.output`
748        field, and returns a dictionary to be merged into the `Analysis.meta`
749    @tolerate_missing_output: if True, when registering the output of this stage,
750        allow for the output file to be missing (only relevant for metamist entry)
751    @required_stages: list of other stage classes that are required prerequisites
752        for this stage. Outputs of those stages will be passed to
753        `Stage.queue_jobs(... , inputs)` as `inputs`, and all required
754        dependencies between Hail Batch jobs will be set automatically as well.
755    @skipped: always skip this stage.
756    @assume_outputs_exist: assume expected outputs of this stage always exist.
757    @forced: always force run this stage, regardless of the outputs' existence.
758    """
759
760    def decorator_stage(_cls) -> StageDecorator:
761        """Implements decorator."""
762
763        @functools.wraps(_cls)
764        def wrapper_stage() -> Stage:
765            """Decorator helper function."""
766            return _cls(
767                name=_cls.__name__,
768                required_stages=required_stages,
769                analysis_type=analysis_type,
770                analysis_keys=analysis_keys,
771                update_analysis_meta=update_analysis_meta,
772                skipped=skipped,
773                assume_outputs_exist=assume_outputs_exist,
774                forced=forced,
775                tolerate_missing_output=tolerate_missing_output,
776            )
777
778        return wrapper_stage
779
780    if cls is None:
781        return decorator_stage
782    return decorator_stage(cls)

Implements a standard class decorator pattern with optional arguments. The goal is to allow declaring workflow stages without having to implement a constructor method. E.g.

@stage(required_stages=[Align])
class GenotypeSample(SequencingGroupStage):
    def expected_outputs(self, sequencing_group: SequencingGroup):
        ...
    def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> StageOutput:
        ...

@analysis_type: if defined, will be used to create/update Analysis entries using the status reporter.
@analysis_keys: if defined, will be used to extract the value for Analysis.output if the Stage.expected_outputs() returns a dict.
@update_analysis_meta: if defined, this function is called on the Analysis.output field, and returns a dictionary to be merged into the Analysis.meta.
@tolerate_missing_output: if True, when registering the output of this stage, allow for the output file to be missing (only relevant for the metamist entry).
@required_stages: list of other stage classes that are required prerequisites for this stage. Outputs of those stages will be passed to Stage.queue_jobs(..., inputs) as inputs, and all required dependencies between Hail Batch jobs will be set automatically as well.
@skipped: always skip this stage.
@assume_outputs_exist: assume expected outputs of this stage always exist.
@forced: always force run this stage, regardless of the outputs' existence.
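
A sketch exercising the analysis-registration parameters, assuming an upstream Align stage and a 'cram' Metamist analysis type (both illustrative):

    @stage(
        required_stages=Align,
        analysis_type='cram',  # create/update an Analysis entry per target
        analysis_keys=['cram'],  # which expected_outputs() keys become Analysis.output
    )
    class RegisterCram(SequencingGroupStage):
        ...  # expected_outputs()/queue_jobs() as usual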

class SequencingGroupStage(cpg_flow.stage.Stage[cpg_flow.targets.sequencing_group.SequencingGroup], abc.ABC):
785class SequencingGroupStage(Stage[SequencingGroup], ABC):
786    """
787    Sequencing Group level stage.
788    """
789
790    @abstractmethod
791    def expected_outputs(self, sequencing_group: SequencingGroup) -> ExpectedResultT:
792        """
793        Override to declare expected output paths.
794        """
795
796    @abstractmethod
797    def queue_jobs(
798        self,
799        sequencing_group: SequencingGroup,
800        inputs: StageInput,
801    ) -> StageOutput | None:
802        """
803        Override to add Hail Batch jobs.
804        """
805
806    def queue_for_multicohort(
807        self,
808        multicohort: MultiCohort,
809    ) -> dict[str, StageOutput | None]:
810        """
811        Plug the stage into the workflow.
812        """
813        output_by_target: dict[str, StageOutput | None] = dict()
814        if not (active_sgs := multicohort.get_sequencing_groups()):
815            all_sgs = len(multicohort.get_sequencing_groups(only_active=False))
816            LOGGER.warning(
817                f'{len(active_sgs)}/{all_sgs} usable (active=True) SGs found in the multicohort. '
818                'Check that `input_cohorts` or `input_datasets` are provided and not skipped',
819            )
820            return output_by_target
821
822        # evaluate the action and queue jobs for each sequencing group
823        for sequencing_group in active_sgs:
824            action = self._get_action(sequencing_group)
825            output_by_target[sequencing_group.target_id] = self._queue_jobs_with_checks(
826                sequencing_group,
827                action,
828            )
829        return output_by_target

Sequencing Group level stage.

@abstractmethod
def expected_outputs( self, sequencing_group: cpg_flow.targets.sequencing_group.SequencingGroup) -> Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType]:
790    @abstractmethod
791    def expected_outputs(self, sequencing_group: SequencingGroup) -> ExpectedResultT:
792        """
793        Override to declare expected output paths.
794        """

Override to declare expected output paths.

@abstractmethod
def queue_jobs( self, sequencing_group: cpg_flow.targets.sequencing_group.SequencingGroup, inputs: StageInput) -> StageOutput | None:
796    @abstractmethod
797    def queue_jobs(
798        self,
799        sequencing_group: SequencingGroup,
800        inputs: StageInput,
801    ) -> StageOutput | None:
802        """
803        Override to add Hail Batch jobs.
804        """

Override to add Hail Batch jobs.

def queue_for_multicohort( self, multicohort: cpg_flow.targets.multicohort.MultiCohort) -> dict[str, StageOutput | None]:
806    def queue_for_multicohort(
807        self,
808        multicohort: MultiCohort,
809    ) -> dict[str, StageOutput | None]:
810        """
811        Plug the stage into the workflow.
812        """
813        output_by_target: dict[str, StageOutput | None] = dict()
814        if not (active_sgs := multicohort.get_sequencing_groups()):
815            all_sgs = len(multicohort.get_sequencing_groups(only_active=False))
816            LOGGER.warning(
817                f'{len(active_sgs)}/{all_sgs} usable (active=True) SGs found in the multicohort. '
818                'Check that `input_cohorts` or `input_datasets` are provided and not skipped',
819            )
820            return output_by_target
821
822        # evaluate the action and queue jobs for each sequencing group
823        for sequencing_group in active_sgs:
824            action = self._get_action(sequencing_group)
825            output_by_target[sequencing_group.target_id] = self._queue_jobs_with_checks(
826                sequencing_group,
827                action,
828            )
829        return output_by_target

Plug the stage into the workflow.
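
Putting the pieces together, a sketch of two chained sequencing-group stages; the stage names, commands, and output keys are all illustrative:

    from cpg_flow.stage import SequencingGroupStage, StageInput, StageOutput, stage
    from cpg_flow.targets import SequencingGroup
    from cpg_utils.hail_batch import get_batch

    @stage
    class Align(SequencingGroupStage):
        def expected_outputs(self, sequencing_group: SequencingGroup):
            return {'cram': self.prefix / f'{sequencing_group.target_id}.cram'}

        def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> StageOutput | None:
            outputs = self.expected_outputs(sequencing_group)
            j = get_batch().new_job('align', self.get_job_attrs(sequencing_group))
            j.command(f'echo aligned > {j.out}')  # stand-in for a real aligner
            get_batch().write_output(j.out, str(outputs['cram']))
            return self.make_outputs(sequencing_group, data=outputs, jobs=j)

    @stage(required_stages=Align)
    class Genotype(SequencingGroupStage):
        def expected_outputs(self, sequencing_group: SequencingGroup):
            return {'gvcf': self.prefix / f'{sequencing_group.target_id}.g.vcf.gz'}

        def queue_jobs(self, sequencing_group: SequencingGroup, inputs: StageInput) -> StageOutput | None:
            cram = inputs.as_path(sequencing_group, Align, key='cram')
            outputs = self.expected_outputs(sequencing_group)
            j = get_batch().new_job('genotype', self.get_job_attrs(sequencing_group))
            j.command(f'echo {cram} > {j.out}')  # stand-in for a real caller
            get_batch().write_output(j.out, str(outputs['gvcf']))
            return self.make_outputs(sequencing_group, data=outputs, jobs=j)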

class DatasetStage(cpg_flow.stage.Stage[cpg_flow.targets.dataset.Dataset], abc.ABC):
832class DatasetStage(Stage[Dataset], ABC):
833    """
834    Dataset-level stage
835    """
836
837    @abstractmethod
838    def expected_outputs(self, dataset: Dataset) -> ExpectedResultT:
839        """
840        Override to declare expected output paths.
841        """
842
843    @abstractmethod
844    def queue_jobs(self, dataset: Dataset, inputs: StageInput) -> StageOutput | None:
845        """
846        Override to add Hail Batch jobs.
847        """
848
849    def queue_for_multicohort(
850        self,
851        multicohort: MultiCohort,
852    ) -> dict[str, StageOutput | None]:
853        """
854        Plug the stage into the workflow.
855        """
856        output_by_target: dict[str, StageOutput | None] = dict()
857        # iterate directly over the datasets in this multicohort
858        for dataset_i, dataset in enumerate(multicohort.get_datasets()):
859            action = self._get_action(dataset)
860            LOGGER.info(f'{self.name}: #{dataset_i + 1}/{dataset} [{action.name}]')
861            output_by_target[dataset.target_id] = self._queue_jobs_with_checks(
862                dataset,
863                action,
864            )
865        return output_by_target

Dataset-level stage

@abstractmethod
def expected_outputs( self, dataset: cpg_flow.targets.dataset.Dataset) -> Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType]:
837    @abstractmethod
838    def expected_outputs(self, dataset: Dataset) -> ExpectedResultT:
839        """
840        Override to declare expected output paths.
841        """

Override to declare expected output paths.

@abstractmethod
def queue_jobs( self, dataset: cpg_flow.targets.dataset.Dataset, inputs: StageInput) -> StageOutput | None:
843    @abstractmethod
844    def queue_jobs(self, dataset: Dataset, inputs: StageInput) -> StageOutput | None:
845        """
846        Override to add Hail Batch jobs.
847        """

Override to add Hail Batch jobs.

def queue_for_multicohort( self, multicohort: cpg_flow.targets.multicohort.MultiCohort) -> dict[str, StageOutput | None]:
849    def queue_for_multicohort(
850        self,
851        multicohort: MultiCohort,
852    ) -> dict[str, StageOutput | None]:
853        """
854        Plug the stage into the workflow.
855        """
856        output_by_target: dict[str, StageOutput | None] = dict()
857        # iterate directly over the datasets in this multicohort
858        for dataset_i, dataset in enumerate(multicohort.get_datasets()):
859            action = self._get_action(dataset)
860            LOGGER.info(f'{self.name}: #{dataset_i + 1}/{dataset} [{action.name}]')
861            output_by_target[dataset.target_id] = self._queue_jobs_with_checks(
862                dataset,
863                action,
864            )
865        return output_by_target

Plug the stage into the workflow.
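
A sketch of a dataset-level aggregation over per-SG results from the assumed Genotype stage above; it also assumes Dataset exposes get_sequencing_groups(), mirroring the MultiCohort accessor used in this module:

    @stage(required_stages=Genotype)
    class DatasetSummary(DatasetStage):
        def expected_outputs(self, dataset: Dataset):
            return {'summary': self.prefix / dataset.name / 'summary.txt'}

        def queue_jobs(self, dataset: Dataset, inputs: StageInput) -> StageOutput | None:
            # one gVCF per sequencing group in this dataset
            gvcfs = [
                inputs.as_path(sg, Genotype, key='gvcf')
                for sg in dataset.get_sequencing_groups()
            ]
            ...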

class CohortStage(cpg_flow.stage.Stage[cpg_flow.targets.cohort.Cohort], abc.ABC):
868class CohortStage(Stage[Cohort], ABC):
869    """
870    Cohort-level stage (one Cohort of a workflow run).
871    """
872
873    @abstractmethod
874    def expected_outputs(self, cohort: Cohort) -> ExpectedResultT:
875        """
876        Override to declare expected output paths.
877        """
878
879    @abstractmethod
880    def queue_jobs(self, cohort: Cohort, inputs: StageInput) -> StageOutput | None:
881        """
882        Override to add Hail Batch jobs.
883        """
884
885    def queue_for_multicohort(
886        self,
887        multicohort: MultiCohort,
888    ) -> dict[str, StageOutput | None]:
889        """
890        Plug the stage into the workflow.
891        """
892        output_by_target: dict[str, StageOutput | None] = dict()
893        for cohort in multicohort.get_cohorts():
894            action = self._get_action(cohort)
895            LOGGER.info(f'{self.name}: {cohort} [{action.name}]')
896            output_by_target[cohort.target_id] = self._queue_jobs_with_checks(
897                cohort,
898                action,
899            )
900        return output_by_target

Cohort-level stage (one Cohort of a workflow run).

@abstractmethod
def expected_outputs( self, cohort: cpg_flow.targets.cohort.Cohort) -> Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType]:
873    @abstractmethod
874    def expected_outputs(self, cohort: Cohort) -> ExpectedResultT:
875        """
876        Override to declare expected output paths.
877        """

Override to declare expected output paths.

@abstractmethod
def queue_jobs( self, cohort: cpg_flow.targets.cohort.Cohort, inputs: StageInput) -> StageOutput | None:
879    @abstractmethod
880    def queue_jobs(self, cohort: Cohort, inputs: StageInput) -> StageOutput | None:
881        """
882        Override to add Hail Batch jobs.
883        """

Override to add Hail Batch jobs.

def queue_for_multicohort( self, multicohort: cpg_flow.targets.multicohort.MultiCohort) -> dict[str, StageOutput | None]:
885    def queue_for_multicohort(
886        self,
887        multicohort: MultiCohort,
888    ) -> dict[str, StageOutput | None]:
889        """
890        Plug the stage into the workflow.
891        """
892        output_by_target: dict[str, StageOutput | None] = dict()
893        for cohort in multicohort.get_cohorts():
894            action = self._get_action(cohort)
895            LOGGER.info(f'{self.name}: {cohort} [{action.name}]')
896            output_by_target[cohort.target_id] = self._queue_jobs_with_checks(
897                cohort,
898                action,
899            )
900        return output_by_target

Plug the stage into the workflow.

class MultiCohortStage(cpg_flow.stage.Stage[cpg_flow.targets.multicohort.MultiCohort], abc.ABC):
903class MultiCohortStage(Stage[MultiCohort], ABC):
904    """
905    MultiCohort-level stage (all datasets of a workflow run).
906    """
907
908    @abstractmethod
909    def expected_outputs(self, multicohort: MultiCohort) -> ExpectedResultT:
910        """
911        Override to declare expected output paths.
912        """
913
914    @abstractmethod
915    def queue_jobs(
916        self,
917        multicohort: MultiCohort,
918        inputs: StageInput,
919    ) -> StageOutput | None:
920        """
921        Override to add Hail Batch jobs.
922        """
923
924    def queue_for_multicohort(
925        self,
926        multicohort: MultiCohort,
927    ) -> dict[str, StageOutput | None]:
928        """
929        Plug the stage into the workflow.
930        """
931        output_by_target: dict[str, StageOutput | None] = dict()
932        action = self._get_action(multicohort)
933        LOGGER.info(f'{self.name}: {multicohort} [{action.name}]')
934        output_by_target[multicohort.target_id] = self._queue_jobs_with_checks(
935            multicohort,
936            action,
937        )
938        return output_by_target

MultiCohort-level stage (all datasets of a workflow run).

@abstractmethod
def expected_outputs( self, multicohort: cpg_flow.targets.multicohort.MultiCohort) -> Union[cloudpathlib.cloudpath.CloudPath, pathlib.Path, dict[str, cloudpathlib.cloudpath.CloudPath | pathlib.Path], dict[str, str], str, NoneType]:
908    @abstractmethod
909    def expected_outputs(self, multicohort: MultiCohort) -> ExpectedResultT:
910        """
911        Override to declare expected output paths.
912        """

Override to declare expected output paths.

@abstractmethod
def queue_jobs( self, multicohort: cpg_flow.targets.multicohort.MultiCohort, inputs: StageInput) -> StageOutput | None:
914    @abstractmethod
915    def queue_jobs(
916        self,
917        multicohort: MultiCohort,
918        inputs: StageInput,
919    ) -> StageOutput | None:
920        """
921        Override to add Hail Batch jobs.
922        """

Override to add Hail Batch jobs.

def queue_for_multicohort( self, multicohort: cpg_flow.targets.multicohort.MultiCohort) -> dict[str, StageOutput | None]:
924    def queue_for_multicohort(
925        self,
926        multicohort: MultiCohort,
927    ) -> dict[str, StageOutput | None]:
928        """
929        Plug the stage into the workflow.
930        """
931        output_by_target: dict[str, StageOutput | None] = dict()
932        action = self._get_action(multicohort)
933        LOGGER.info(f'{self.name}: {multicohort} [{action.name}]')
934        output_by_target[multicohort.target_id] = self._queue_jobs_with_checks(
935            multicohort,
936            action,
937        )
938        return output_by_target

Plug the stage into the workflow.
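
Since a workflow run has exactly one MultiCohort target, this level suits run-wide outputs. A final sketch, building on the assumed DatasetSummary stage above:

    @stage(required_stages=DatasetSummary)
    class RunReport(MultiCohortStage):
        def expected_outputs(self, multicohort: MultiCohort):
            return {'report': self.prefix / 'report.html'}

        def queue_jobs(self, multicohort: MultiCohort, inputs: StageInput) -> StageOutput | None:
            # one summary per dataset, indexed by target
            summaries = inputs.as_dict_by_target(DatasetSummary)
            ...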