Coverage for src/snakemake_storage_plugin_orcestra/__init__.py: 80%

145 statements  

coverage.py v7.6.12, created at 2025-03-11 17:24 +0100

import asyncio
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, List, Optional
from urllib import parse

from orcestradownloader.dataset_config import DATASET_CONFIG
from orcestradownloader.logging_config import logger as orcestra_logger
from orcestradownloader.managers import (
    REGISTRY,
    DatasetManager,
    UnifiedDataManager,
)
from orcestradownloader.models.base import BaseModel

# Raise errors that will not be handled within this plugin but thrown upwards to
# Snakemake and the user as WorkflowError.
from snakemake_interface_common.exceptions import WorkflowError  # noqa: F401
from snakemake_interface_storage_plugins.io import IOCacheStorageInterface
from snakemake_interface_storage_plugins.settings import (
    StorageProviderSettingsBase,
)
from snakemake_interface_storage_plugins.storage_object import (
    StorageObjectRead,
    retry_decorator,
)
from snakemake_interface_storage_plugins.storage_provider import (  # noqa: F401
    ExampleQuery,
    Operation,
    QueryType,
    StorageProviderBase,
    StorageQueryValidationResult,
)

if TYPE_CHECKING:
    from datetime import datetime


# Register all dataset managers automatically
for name, config in DATASET_CONFIG.items():
    manager = DatasetManager(
        url=config.url,
        cache_file=config.cache_file,
        dataset_type=config.dataset_type,
    )
    REGISTRY.register(name, manager)

unified_manager = UnifiedDataManager(REGISTRY, force=True)
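
# As an illustration: with "pharmacosets" among the configured dataset types
# (it is used in the example query below), REGISTRY routes a query like
# "orcestra://pharmacosets/CCLE_2015" to the matching DatasetManager.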
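
# Quiet the downloader's own logging: raise the level and detach its handlers
# so it does not write console output alongside Snakemake's log.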
orcestra_logger.setLevel("WARNING")

for handler in orcestra_logger.handlers[:]:
    orcestra_logger.removeHandler(handler)
    handler.close()


@dataclass
class StorageProviderSettings(StorageProviderSettingsBase):
    pass
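
    # No plugin-specific settings are defined yet. If some were needed, the
    # convention in snakemake-interface-storage-plugins is dataclass fields
    # carrying help text in their metadata, e.g. (hypothetical field):
    #
    #     retries: Optional[int] = field(
    #         default=3,
    #         metadata={"help": "Number of times to retry a failed download."},
    #     )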


# Required:
# Implementation of your storage provider
# This class can be empty as the one below.
# You can however use it to store global information or maintain e.g. a connection
# pool.
class StorageProvider(StorageProviderBase):
    # For compatibility with future changes, you should not overwrite the __init__
    # method. Instead, use __post_init__ to set additional attributes and initialize
    # further stuff.

    def __post_init__(self) -> None:
        # This is optional and can be removed if not needed.
        # Alternatively, you can e.g. prepare a connection to your storage
        # backend here and set additional attributes.
        pass

    @classmethod
    def example_queries(cls) -> List[ExampleQuery]:
        """Return example queries with descriptions for this storage provider
        (at least one)."""
        return [
            ExampleQuery(
                query="orcestra://pharmacosets/CCLE_2015",
                description="Download the CCLE 2015 dataset.",
                type=QueryType.INPUT,
            )
        ]

    def rate_limiter_key(self, query: str, operation: Operation) -> Any:  # noqa: ANN401
        """Return a key for identifying a rate limiter given a query and an operation.

        Notes
        -----
        Unused in orcestra-downloader
        """
        return None

    def default_max_requests_per_second(self) -> float:
        """Return the default maximum number of requests per second for this
        storage provider.

        Notes
        -----
        Unused in orcestra-downloader
        """
        return 0.0

    def use_rate_limiter(self) -> bool:
        """Return False if no rate limiting is needed for this provider.

        Notes
        -----
        Unused in orcestra-downloader
        """
        return False

    @classmethod
    def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
        """Return whether the given query is valid for this storage provider."""
        # Ensure that also queries containing wildcards (e.g. {sample}) are accepted
        # and considered valid. The wildcards will be resolved before the storage
        # object is actually used.
        datatypes = list(unified_manager.names())
        errormsg = ""
        try:
            parsed_query = parse.urlparse(query)
        except Exception as e:
            errormsg = f"cannot be parsed as URL ({e})"
        else:
            if parsed_query.scheme != "orcestra":
                errormsg = (
                    f"Invalid scheme in query '{query}'. "
                    f"{parsed_query.scheme} should be 'orcestra'."
                )
            elif parsed_query.netloc not in datatypes:
                errormsg = (
                    f"Invalid netloc in query '{query}'. "
                    f"{parsed_query.netloc} should be one of {datatypes}."
                )
            elif parsed_query.path:
                # remove the slash at the beginning
                dataset_name = parsed_query.path[1:]
                # check if there are still slashes in the path
                if "/" in dataset_name:
                    errormsg = (
                        f"Invalid path in query '{query}'. "
                        f"Format should follow"
                        f" 'orcestra://<datatype>/<dataset_name>' but got '{parsed_query}'."
                    )

        if errormsg:
            orcestra_logger.error(errormsg)
            return StorageQueryValidationResult(query, False, errormsg)

        return StorageQueryValidationResult(query, True, "")
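
    # For example, "orcestra://pharmacosets/CCLE_2015" passes validation, while
    # "http://pharmacosets/CCLE_2015" (wrong scheme) and
    # "orcestra://pharmacosets/a/b" (extra path segment) are rejected.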


# Required:
# Implementation of storage object. If certain methods cannot be supported by your
# storage (e.g. because it is read-only; see snakemake-storage-http for comparison),
# remove the corresponding base classes from the list of inherited items.
class StorageObject(StorageObjectRead):
    # following attributes are inherited from StorageObjectRead:
    # query = query
    # keep_local = keep_local
    # retrieve = retrieve
    # provider = provider
    # _overwrite_local_path = None

    dataset_type: str = field(init=False)
    dataset_name: str = field(init=False)
    manager: DatasetManager = field(init=False)
    dataset_metadata: BaseModel | None = field(init=False)

    def __post_init__(self) -> None:
        # This is optional and can be removed if not needed.
        # Alternatively, you can e.g. prepare a connection to your storage
        # backend here and set additional attributes.
        orcestra_logger.debug(f"StorageObject for query {self.query} created.")
        orcestra_logger.debug(
            f"Arguments: {self.keep_local=}, {self.retrieve=}, "
            f"{self.provider=}, {self._overwrite_local_path=}"
        )
        parsed = parse.urlparse(self.query)
        self.dataset_type = parsed.netloc
        self.dataset_name = parsed.path.split("/")[1]

        orcestra_logger.debug(
            f"Dataset type: {self.dataset_type} and name: {self.dataset_name}"
        )

        # initialize manager for this datatype
        self.manager = unified_manager.registry.get_manager(self.dataset_type)

        # use unified manager to fetch info
        try:
            # Get the current event loop or create one if it doesn't exist
            loop = asyncio.get_event_loop()
            if loop.is_running():
                # If a loop is already running, schedule the fetch as a task on it
                asyncio.create_task(
                    unified_manager.fetch_by_name(self.dataset_type, force=True)
                )
            else:
                # If no loop is running, run the coroutine to completion
                loop.run_until_complete(
                    unified_manager.fetch_by_name(self.dataset_type, force=True)
                )
        except RuntimeError:
            # Fallback to creating a new loop if needed
            asyncio.run(
                unified_manager.fetch_by_name(self.dataset_type, force=True)
            )
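        # Note: when a loop is already running, the fetch above is scheduled
        # fire-and-forget, so the metadata lookup below may rely on whatever
        # the manager has already cached.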

        try:
            self.dataset_metadata = self.manager[self.dataset_name]
        except ValueError:
            self.dataset_metadata = None

    async def inventory(self, cache: IOCacheStorageInterface) -> None:
        """From this file, try to find as much existence and modification date
        information as possible. Only retrieve that information that comes for free
        given the current object.
        """
        # This is optional and can be left as is

        # If this is implemented in a storage object, results have to be stored in
        # the given IOCache object, using self.cache_key() as key.
        # Optionally, this can take a custom local suffix, needed e.g. when you want
        # to cache more items than the current query: self.cache_key(local_suffix=...)
        pass

    def get_inventory_parent(self) -> Optional[str]:
        """Return the parent directory of this object."""
        # this is optional and can be left as is
        return None

    def local_suffix(self) -> str:
        """Return a unique suffix for the local path, determined from self.query."""
        parsed = parse.urlparse(self.query)
        return f"{parsed.netloc}{parsed.path}.RDS"
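        # For illustration, "orcestra://pharmacosets/CCLE_2015" yields the
        # local suffix "pharmacosets/CCLE_2015.RDS".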

    def cleanup(self) -> None:
        """Perform local cleanup of any remainders of the storage object."""
        # self.local_path() should not be removed, as this is taken care of by
        # Snakemake.
        ...

    # Fallible methods should implement some retry logic.
    # The easiest way to do this (but not the only one) is to use the retry_decorator
    # provided by snakemake-interface-storage-plugins.
    @retry_decorator
    def exists(self) -> bool:
        if self.dataset_metadata:
            return True
        from difflib import get_close_matches

        dataset_names = self.manager.names()

        if self.dataset_name not in dataset_names:
            errmsg = (
                f"Dataset {self.dataset_name} not found in {self.dataset_type}. "
                f"Did you mean one of {get_close_matches(self.dataset_name, dataset_names)}?"
            )
            orcestra_logger.error(errmsg)
        return False

    @retry_decorator
    def mtime(self) -> float:
        # return the modification time
        if self.dataset_metadata is None:
            # return negative infinity if no date is available
            return float("-inf")

        created_date: datetime | None = self.dataset_metadata.date_created

        if created_date is None:
            # return negative infinity if no date is available
            return float("-inf")
        return float(created_date.timestamp())
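        # Negative infinity makes the object look infinitely old, so Snakemake
        # should never consider an existing local copy outdated when no
        # creation date is available.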

    @retry_decorator
    def size(self) -> int:
        # return the size in bytes
        return 0
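        # The remote size is not queried up front; 0 appears to act as a
        # placeholder value here.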

    @retry_decorator
    def retrieve_object(self) -> None:
        # Ensure that the object is accessible locally under self.local_path()
        directory_path = Path(self.local_path()).parent
        if not directory_path.exists():
            directory_path.mkdir(parents=True)
        from rich.progress import Progress

        if self.dataset_metadata is None:
            errmsg = f"Dataset {self.dataset_name} not found in {self.dataset_type}."
            orcestra_logger.error(errmsg)
            raise WorkflowError(errmsg)

        download_url = self.dataset_metadata.download_link

        if download_url is None:
            errmsg = (
                f"Download URL for dataset {self.dataset_name} "
                f"not found in {self.dataset_type}."
            )
            orcestra_logger.error(errmsg)
            raise WorkflowError(errmsg)

        import requests

        with Progress() as progress:
            temp_file = Path(f"{self.local_path()}.temp")
            # Download the dataset to a temporary file first
            with temp_file.open("wb") as f:
                response = requests.get(download_url, stream=True)
                # get the total size from the content-length header
                total_size = int(response.headers.get("content-length", 0))

                if total_size == 0:
                    orcestra_logger.warning(
                        f"Could not determine total size of download for {self.dataset_name}."
                    )
                # a total of None renders an indeterminate progress bar
                task = progress.add_task(
                    f"Downloading {self.dataset_name} from {self.dataset_type}",
                    total=total_size or None,
                )
                for chunk in response.iter_content(chunk_size=1024):
                    f.write(chunk)
                    progress.update(task, advance=len(chunk))

            # rename the temp file to the final file
            temp_file.rename(self.local_path())

        orcestra_logger.debug(f"Downloaded dataset to {self.local_path()}")
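
    # A minimal Snakefile usage sketch (assuming this plugin is installed and
    # the "orcestra" scheme is therefore available to the storage() directive):
    #
    #     rule all:
    #         input:
    #             storage("orcestra://pharmacosets/CCLE_2015")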

    # The following two methods are only required if the class inherits from
    # StorageObjectReadWrite.

    # @retry_decorator
    # def store_object(self) -> None:
    #     # Ensure that the object is stored at the location specified by
    #     # self.local_path().
    #     ...

    # @retry_decorator
    # def remove(self) -> None:
    #     # Remove the object from the storage.
    #     ...

    # The following two methods are only required if the class inherits from
    # StorageObjectGlob.

    # @retry_decorator
    # def list_candidate_matches(self) -> Iterable[str]:
    #     """Return a list of candidate matches in the storage for the query."""
    #     # This is used by glob_wildcards() to find matches for wildcards in the query.
    #     # The method has to return concretized queries without any remaining wildcards.
    #     # Use snakemake_executor_plugins.io.get_constant_prefix(self.query) to get the
    #     # prefix of the query before the first wildcard.
    #     ...