# src/snakemake_storage_plugin_orcestra/__init__.py

import asyncio
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional
from urllib import parse

from orcestradownloader.dataset_config import DATASET_CONFIG
from orcestradownloader.logging_config import logger as orcestra_logger
from orcestradownloader.managers import (
    REGISTRY,
    DatasetManager,
    UnifiedDataManager,
)
from orcestradownloader.models.base import BaseModel

# Raise errors that will not be handled within this plugin but thrown upwards to
# Snakemake and the user as WorkflowError.
from snakemake_interface_common.exceptions import WorkflowError  # noqa: F401
from snakemake_interface_storage_plugins.io import IOCacheStorageInterface
from snakemake_interface_storage_plugins.settings import (
    StorageProviderSettingsBase,
)
from snakemake_interface_storage_plugins.storage_object import (
    StorageObjectRead,
    retry_decorator,
)
from snakemake_interface_storage_plugins.storage_provider import (  # noqa: F401
    ExampleQuery,
    Operation,
    QueryType,
    StorageProviderBase,
    StorageQueryValidationResult,
)

if TYPE_CHECKING:
    from datetime import datetime

# Register all dataset managers automatically
for name, config in DATASET_CONFIG.items():
    manager = DatasetManager(
        url=config.url,
        cache_file=config.cache_file,
        dataset_type=config.dataset_type,
    )
    REGISTRY.register(name, manager)

unified_manager = UnifiedDataManager(REGISTRY, force=True)

orcestra_logger.setLevel("WARNING")

# Detach any handlers installed by orcestra-downloader to avoid duplicate
# console output alongside Snakemake's own logging.
for handler in orcestra_logger.handlers[:]:
    orcestra_logger.removeHandler(handler)
    handler.close()


@dataclass
class StorageProviderSettings(StorageProviderSettingsBase):
    pass
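    # Settings fields could be declared here if the plugin ever needs
    # user-facing configuration; a hypothetical sketch following the dataclass
    # settings convention used by Snakemake storage plugins:
    #
    #   cache_dir: Optional[str] = field(
    #       default=None,
    #       metadata={"help": "Directory for the orcestra-downloader cache."},
    #   )

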
# Required:
# Implementation of your storage provider.
# This class can be left empty, like the one below.
# You can, however, use it to store global information or maintain e.g. a
# connection pool.
class StorageProvider(StorageProviderBase):
    # For compatibility with future changes, you should not overwrite the __init__
    # method. Instead, use __post_init__ to set additional attributes and
    # initialize further stuff.

    def __post_init__(self) -> None:
        # This is optional and can be removed if not needed.
        # Alternatively, you can e.g. prepare a connection to your storage
        # backend here and set additional attributes.
        pass

    @classmethod
    def example_queries(cls) -> List[ExampleQuery]:
        """Return example queries with a description for this storage provider
        (at least one)."""
        return [
            ExampleQuery(
                query="orcestra://pharmacosets/CCLE_2015",
                description="Download the CCLE 2015 dataset.",
                type=QueryType.INPUT,
            )
        ]

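    # Illustrative use from a Snakefile; the rule and output names are
    # hypothetical, and this assumes the storage() input syntax of
    # Snakemake >= 8:
    #
    #   rule get_ccle:
    #       input:
    #           storage("orcestra://pharmacosets/CCLE_2015")
    #       output:
    #           "results/CCLE_2015.RDS"
    #       shell:
    #           "cp {input} {output}"
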
    def rate_limiter_key(self, query: str, operation: Operation) -> Any:  # noqa: ANN401
        """Return a key for identifying a rate limiter given a query and an operation.

        Notes
        -----
        Unused in orcestra-downloader.
        """
        return None

    def default_max_requests_per_second(self) -> float:
        """Return the default maximum number of requests per second for this
        storage provider.

        Notes
        -----
        Unused in orcestra-downloader.
        """
        return 0.0

    def use_rate_limiter(self) -> bool:
        """Return False if no rate limiting is needed for this provider.

        Notes
        -----
        Unused in orcestra-downloader.
        """
        return False

    @classmethod
    def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
        """Return whether the given query is valid for this storage provider."""
        # Ensure that also queries containing wildcards (e.g. {sample}) are accepted
        # and considered valid. The wildcards will be resolved before the storage
        # object is actually used.
        datatypes = list(unified_manager.names())
        errormsg = ""
        try:
            parsed_query = parse.urlparse(query)
        except Exception as e:
            errormsg = f"cannot be parsed as URL ({e})"
        else:
            if parsed_query.scheme != "orcestra":
                errormsg = (
                    f"Invalid scheme in query '{query}'. "
                    f"'{parsed_query.scheme}' should be 'orcestra'."
                )
            elif parsed_query.netloc not in datatypes:
                errormsg = (
                    f"Invalid netloc in query '{query}'. "
                    f"'{parsed_query.netloc}' should be one of {datatypes}."
                )
            elif parsed_query.path:
                # remove the slash at the beginning
                dataset_name = parsed_query.path[1:]
                # check if there are still slashes in the path
                if "/" in dataset_name:
                    errormsg = (
                        f"Invalid path in query '{query}'. "
                        f"Format should follow"
                        f" 'orcestra://<datatype>/<dataset_name>' but got '{parsed_query}'."
                    )

        if errormsg:
            orcestra_logger.error(errormsg)
            return StorageQueryValidationResult(query, False, errormsg)

        return StorageQueryValidationResult(query, True, "")

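    # Illustrative example: "orcestra://pharmacosets/CCLE_2015" parses to
    # scheme="orcestra", netloc="pharmacosets", path="/CCLE_2015" and passes
    # the checks above, while "orcestra://pharmacosets/CCLE/2015" would be
    # rejected for the extra slash in the path.

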
# Required:
# Implementation of the storage object. If certain methods cannot be supported
# by your storage (e.g. because it is read-only, see snakemake-storage-http for
# comparison), remove the corresponding base classes from the list of inherited
# items.
class StorageObject(StorageObjectRead):
    # The following attributes are inherited from StorageObjectRead:
    # query = query
    # keep_local = keep_local
    # retrieve = retrieve
    # provider = provider
    # _overwrite_local_path = None

    dataset_type: str = field(init=False)
    dataset_name: str = field(init=False)
    manager: DatasetManager = field(init=False)
    dataset_metadata: BaseModel | None = field(init=False)

    def __post_init__(self) -> None:
        orcestra_logger.debug(f"StorageObject for query {self.query} created.")
        orcestra_logger.debug(
            f"Arguments: {self.keep_local=}, {self.retrieve=}, "
            f"{self.provider=}, {self._overwrite_local_path=}"
        )
        parsed = parse.urlparse(self.query)
        self.dataset_type = parsed.netloc
        self.dataset_name = parsed.path.split("/")[1]

        orcestra_logger.debug(
            f"Dataset type: {self.dataset_type} and name: {self.dataset_name}"
        )

        # initialize the manager for this datatype
        self.manager = unified_manager.registry.get_manager(self.dataset_type)

        # use the unified manager to fetch the dataset listing
        try:
            # Get the current event loop or create one if it doesn't exist
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # If a loop is already running, schedule the fetch as a task.
                # Note that the task is not awaited here, so the metadata lookup
                # below may run before the fetch has completed.
                asyncio.create_task(
                    unified_manager.fetch_by_name(self.dataset_type, force=True)
                )
            else:
                # If no loop is running, drive the coroutine to completion
                loop.run_until_complete(
                    unified_manager.fetch_by_name(self.dataset_type, force=True)
                )
        except RuntimeError:
            # Fallback to creating a new event loop if needed
            asyncio.run(unified_manager.fetch_by_name(self.dataset_type, force=True))

        try:
            self.dataset_metadata = self.manager[self.dataset_name]
        except ValueError:
            self.dataset_metadata = None

    async def inventory(self, cache: IOCacheStorageInterface) -> None:
        """From this file, try to find as much existence and modification date
        information as possible. Only retrieve the information that comes for
        free given the current object.
        """
        # This is optional and can be left as is.

        # If this is implemented in a storage object, results have to be stored
        # in the given IOCache object, using self.cache_key() as key.
        # Optionally, this can take a custom local suffix, needed e.g. when you
        # want to cache more items than the current query:
        # self.cache_key(local_suffix=...)
        pass

    def get_inventory_parent(self) -> Optional[str]:
        """Return the parent directory of this object."""
        # this is optional and can be left as is
        return None

    def local_suffix(self) -> str:
        """Return a unique suffix for the local path, determined from self.query."""
        parsed = parse.urlparse(self.query)
        return f"{parsed.netloc}{parsed.path}.RDS"
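    # Example: the query "orcestra://pharmacosets/CCLE_2015" yields the local
    # suffix "pharmacosets/CCLE_2015.RDS" under Snakemake's storage directory.
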
    def cleanup(self) -> None:
        """Perform local cleanup of any remainders of the storage object."""
        # self.local_path() should not be removed, as this is taken care of by
        # Snakemake.
        ...

    # Fallible methods should implement some retry logic.
    # The easiest way to do this (but not the only one) is to use the
    # retry_decorator provided by snakemake-interface-storage-plugins.
    @retry_decorator
    def exists(self) -> bool:
        if self.dataset_metadata:
            return True
        from difflib import get_close_matches

        dataset_names = self.manager.names()

        if self.dataset_name not in dataset_names:
            errmsg = (
                f"Dataset {self.dataset_name} not found in {self.dataset_type}. "
                f"Did you mean one of {get_close_matches(self.dataset_name, dataset_names)}?"
            )
            orcestra_logger.error(errmsg)
        # Without metadata the object cannot be retrieved, so report it as
        # absent even if the name matches a known dataset.
        return False

    @retry_decorator
    def mtime(self) -> float:
        # return the modification time
        if self.dataset_metadata is None:
            # return negative infinity if no date is available
            return float("-inf")

        created_date: datetime | None = self.dataset_metadata.date_created

        if created_date is None:
            # return negative infinity if no date is available
            return float("-inf")
        return float(created_date.timestamp())
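    # Note (interpretation): negative infinity makes the dataset appear
    # infinitely old, so a missing creation date should never by itself mark
    # downstream outputs as outdated.
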
    @retry_decorator
    def size(self) -> int:
        # return the size in bytes; the size is not known before downloading,
        # so report 0
        return 0

    @retry_decorator
    def retrieve_object(self) -> None:
        # Ensure that the object is accessible locally under self.local_path()
        directory_path = Path(self.local_path()).parent
        directory_path.mkdir(parents=True, exist_ok=True)

        if self.dataset_metadata is None:
            errmsg = f"Dataset {self.dataset_name} not found in {self.dataset_type}."
            orcestra_logger.error(errmsg)
            raise WorkflowError(errmsg)

        download_url = self.dataset_metadata.download_link

        if download_url is None:
            errmsg = (
                f"Download URL for dataset {self.dataset_name} "
                f"not found in {self.dataset_type}."
            )
            orcestra_logger.error(errmsg)
            raise WorkflowError(errmsg)

        import requests
        from rich.progress import Progress

        with Progress() as progress:
            temp_file = Path(f"{self.local_path()}.temp")
            # Stream the dataset into a temporary file first so that a partial
            # download never ends up at the final path
            with temp_file.open("wb") as f:
                response = requests.get(download_url, stream=True)
                response.raise_for_status()
                # get the total size from the content-length header
                total_size = int(response.headers.get("content-length", 0))

                if total_size == 0:
                    orcestra_logger.warning(
                        f"Could not determine total size of download for {self.dataset_name}."
                    )
                task = progress.add_task(
                    f"Downloading {self.dataset_name} from {self.dataset_type}",
                    total=total_size or None,
                )
                for chunk in response.iter_content(chunk_size=1024):
                    f.write(chunk)
                    progress.update(task, advance=len(chunk))

            # rename the temp file to the final file
            temp_file.rename(self.local_path())
            progress.update(task, completed=total_size)

        orcestra_logger.debug(f"Downloaded dataset to {self.local_path()}")

    # The following two methods are only required if the class inherits from
    # StorageObjectReadWrite.

    # @retry_decorator
    # def store_object(self) -> None:
    #     # Ensure that the object is stored at the location specified by
    #     # self.local_path().
    #     ...

    # @retry_decorator
    # def remove(self) -> None:
    #     # Remove the object from the storage.
    #     ...

    # The following method is only required if the class inherits from
    # StorageObjectGlob.

    # @retry_decorator
    # def list_candidate_matches(self) -> Iterable[str]:
    #     """Return a list of candidate matches in the storage for the query."""
    #     # This is used by glob_wildcards() to find matches for wildcards in the query.
    #     # The method has to return concretized queries without any remaining wildcards.
    #     # Use snakemake_executor_plugins.io.get_constant_prefix(self.query) to get the
    #     # prefix of the query before the first wildcard.
    #     ...