Coverage for src/fastoai/routers/files.py: 43%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-12-06 09:34 +0800

1from shutil import copyfileobj 

2from typing import Annotated, cast 

3 

4from fastapi import APIRouter, Form, Query, UploadFile 

5from fastapi.responses import FileResponse 

6from openai.types.file_deleted import FileDeleted 

7from openai.types.file_list_params import FileListParams 

8from openai.types.file_purpose import FilePurpose 

9from sqlmodel import col, select 

10 

11from ..dependencies import SessionDependency, SettingsDependency 

12from ..models import FileObject 

13from ..pagination import AsyncCursorPage 

14from ._utils import create_model_from 

15 

16router = APIRouter(tags=["Files"]) 

17 

18 

@router.post("/files")
async def upload_file(
    file: UploadFile,
    purpose: Annotated[FilePurpose, Form()],
    settings: SettingsDependency,
    session: SessionDependency,
) -> FileObject:
    """Store an uploaded file on disk and record it in the database.

    The payload is written into ``settings.upload_dir`` under the new
    record's ``id``; the record is then committed, refreshed and returned.

    Args:
        file: The multipart upload.
        purpose: Intended purpose of the file, sent as a form field.
        settings: Application settings; supplies ``upload_dir``.
        session: Database session used to persist the record.

    Returns:
        The persisted file record, re-validated from its dumped fields.

    Raises:
        Exception: Anything raised by ``session.commit()`` is re-raised
            after the payload written by this call has been removed.
    """
    file_object = FileObject.model_validate(
        {
            "bytes": file.size,
            "filename": file.filename,
            "purpose": purpose,
            "status": "uploaded",
        }
    )
    destination = settings.upload_dir / file_object.id
    # NOTE(review): copyfileobj is a synchronous copy inside an async handler;
    # consider off-loading it to a worker thread for large uploads.
    with destination.open("wb") as f:
        copyfileobj(file.file, f)
    session.add(file_object)
    try:
        await session.commit()
    except Exception:
        # The row was not saved, so do not leave an orphaned payload behind.
        # Assumes upload_dir is a pathlib.Path (it is already used with
        # ``/`` and ``.open``) -- TODO confirm.
        destination.unlink(missing_ok=True)
        raise
    await session.refresh(file_object)
    return FileObject.model_validate(file_object.model_dump())

40 

41 

@router.get("/files")
async def list_files(
    q: Annotated[create_model_from(FileListParams), Query()],  # type: ignore
    session: SessionDependency,
) -> AsyncCursorPage[FileObject]:
    """List stored file records, optionally filtered, ordered and paged.

    Query parameters are read from ``q`` (a model built from
    ``FileListParams``): ``purpose`` filters by purpose, ``order`` picks the
    sort direction on ``created_at``, ``after`` skips everything up to and
    including the file with that id, and ``limit`` caps the page size.
    Parameters that are missing or falsy are ignored.

    Returns:
        A page whose ``data`` holds the matching file records.
    """
    params = cast(FileListParams, q.model_dump())
    purpose = params.get("purpose")
    order = params.get("order")
    after = params.get("after")
    limit = params.get("limit")

    query = select(FileObject)
    if purpose:
        query = query.where(FileObject.purpose == purpose)
    if order:
        # ``order`` names the column method to call to build the sort clause.
        direction = getattr(col(FileObject.created_at), order)
        query = query.order_by(direction())
    if after:
        # The cursor is resolved by position: run the filtered/ordered query,
        # find the anchor row in the result and start right after it.
        # NOTE(review): ``index`` raises ValueError when the anchor is not in
        # the result (e.g. excluded by the purpose filter) -- confirm how
        # that should surface to the client.
        candidates = (await session.exec(query)).all()
        anchor = await session.get_one(FileObject, after)
        query = query.offset(candidates.index(anchor) + 1)
    if limit:
        query = query.limit(limit)

    rows = (await session.exec(query)).all()
    page_items = [FileObject.model_validate(row.model_dump()) for row in rows]
    return AsyncCursorPage[FileObject](data=page_items)

64 

65 

@router.get("/files/{file_id}")
async def retrieve_file(
    file_id: str,
    session: SessionDependency,
) -> FileObject:
    """Return the stored record for ``file_id``.

    The row is fetched with ``session.get_one`` and re-validated from its
    dumped fields before being returned.
    """
    record = await session.get_one(FileObject, file_id)
    payload = record.model_dump()
    return FileObject.model_validate(payload)

73 

74 

@router.get("/files/{file_id}/content")
async def retrieve_file_content(
    file_id: str,
    settings: SettingsDependency,
    session: SessionDependency,
):
    """Send back the stored payload of the file identified by ``file_id``.

    The record is looked up first; the response then serves the file found
    in ``settings.upload_dir`` under the record's id, using the record's
    ``filename`` as the response filename.
    """
    record = await session.get_one(FileObject, file_id)
    stored_path = settings.upload_dir / record.id
    return FileResponse(stored_path, filename=record.filename)

83 

84 

@router.delete("/files/{file_id}", response_model=FileDeleted)
async def delete_file(
    file_id: str,
    session: SessionDependency,
):
    """Delete the database record for ``file_id`` and confirm the deletion.

    Returns:
        A ``FileDeleted`` acknowledgement carrying the requested id.
    """
    target = await session.get_one(FileObject, file_id)
    # NOTE(review): only the database row is removed here; nothing in this
    # handler touches the payload written to the upload directory.
    await session.delete(target)
    await session.commit()
    return FileDeleted(object="file", deleted=True, id=file_id)