Coverage for src/fastoai/routers/files.py: 43%
53 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-12-06 09:34 +0800
1from shutil import copyfileobj
2from typing import Annotated, cast
4from fastapi import APIRouter, Form, Query, UploadFile
5from fastapi.responses import FileResponse
6from openai.types.file_deleted import FileDeleted
7from openai.types.file_list_params import FileListParams
8from openai.types.file_purpose import FilePurpose
9from sqlmodel import col, select
11from ..dependencies import SessionDependency, SettingsDependency
12from ..models import FileObject
13from ..pagination import AsyncCursorPage
14from ._utils import create_model_from
# Router grouping every /files endpoint under the "Files" tag.
router: APIRouter = APIRouter(tags=["Files"])
def _save_upload(source: UploadFile, destination) -> None:
    """Copy the upload's underlying file object to ``destination`` (blocking)."""
    with destination.open("wb") as f:
        copyfileobj(source.file, f)


@router.post("/files")
async def upload_file(
    file: UploadFile,
    purpose: Annotated[FilePurpose, Form()],
    settings: SettingsDependency,
    session: SessionDependency,
) -> FileObject:
    """Persist an uploaded file to disk and record its metadata.

    The content is written to ``settings.upload_dir / <file id>`` and a
    ``FileObject`` row with status ``"uploaded"`` is committed.

    Args:
        file: The multipart upload.
        purpose: Declared purpose of the file (form field).
        settings: Application settings providing ``upload_dir``.
        session: Database session used to store the metadata row.

    Returns:
        A fresh ``FileObject`` built from the committed row's data.
    """
    import asyncio

    file_object = FileObject.model_validate(
        {
            "bytes": file.size,
            "filename": file.filename,
            "purpose": purpose,
            "status": "uploaded",
        }
    )
    destination = settings.upload_dir / file_object.id
    # copyfileobj and open() block; run them in a worker thread so this async
    # handler does not stall the event loop for large uploads.
    await asyncio.to_thread(_save_upload, file, destination)
    session.add(file_object)
    try:
        await session.commit()
    except BaseException:
        # Do not leave stored bytes behind without a matching database row.
        destination.unlink(missing_ok=True)
        raise
    await session.refresh(file_object)
    return FileObject.model_validate(file_object.model_dump())
@router.get("/files")
async def list_files(
    q: Annotated[create_model_from(FileListParams), Query()],  # type: ignore
    session: SessionDependency,
) -> AsyncCursorPage[FileObject]:
    """List stored files with optional filtering, ordering and pagination.

    Query parameters follow ``FileListParams``:

    - ``purpose`` filters on ``FileObject.purpose``.
    - ``order`` names the ``created_at`` column method used for sorting.
    - ``after`` is the id of the file the page starts after.
    - ``limit`` caps the number of returned files.

    Raises:
        HTTPException: 404 when ``after`` does not identify a file within the
            filtered result set.
    """
    from fastapi import HTTPException

    params = cast(FileListParams, q.model_dump())
    statement = select(FileObject)
    if purpose := params.get("purpose"):
        statement = statement.where(FileObject.purpose == purpose)
    if order := params.get("order"):
        statement = statement.order_by(getattr(col(FileObject.created_at), order)())
    if after := params.get("after"):
        # Offset-style cursor: find the cursor's position within the same
        # filtered/ordered result. Matching by id means a cursor that exists
        # but is excluded by the filters yields a 404 instead of an uncaught
        # ValueError (HTTP 500).
        # NOTE(review): without `order` there is no ORDER BY, so positions are
        # not guaranteed stable between this query and the final one.
        ids = [row.id for row in (await session.exec(statement)).all()]
        try:
            offset = ids.index(after) + 1
        except ValueError:
            raise HTTPException(
                status_code=404, detail=f"No file found with id {after!r}"
            ) from None
        statement = statement.offset(offset)
    if limit := params.get("limit"):
        statement = statement.limit(limit)
    files = (await session.exec(statement)).all()
    return AsyncCursorPage[FileObject](
        data=[FileObject.model_validate(file.model_dump()) for file in files]
    )
@router.get("/files/{file_id}")
async def retrieve_file(
    file_id: str,
    session: SessionDependency,
) -> FileObject:
    """Fetch one stored file's metadata by its id.

    Lookup goes through ``session.get_one``; the result is re-validated from
    its dumped fields into a new ``FileObject``.
    """
    stored = await session.get_one(FileObject, file_id)
    snapshot = stored.model_dump()
    return FileObject.model_validate(snapshot)
@router.get("/files/{file_id}/content")
async def retrieve_file_content(
    file_id: str,
    settings: SettingsDependency,
    session: SessionDependency,
):
    """Send back the stored bytes of a file.

    The file record is looked up with ``session.get_one``. Its content is
    served from ``settings.upload_dir / <id>``, with the record's
    ``filename`` passed to ``FileResponse``.
    """
    record = await session.get_one(FileObject, file_id)
    path_on_disk = settings.upload_dir / record.id
    return FileResponse(path_on_disk, filename=record.filename)
@router.delete("/files/{file_id}", response_model=FileDeleted)
async def delete_file(
    file_id: str,
    session: SessionDependency,
    settings: SettingsDependency,
):
    """Delete a file's metadata row and its stored content.

    Args:
        file_id: Id of the file to delete.
        session: Database session holding the ``FileObject`` row.
        settings: Application settings providing ``upload_dir``.

    Returns:
        A ``FileDeleted`` acknowledgement for ``file_id``.
    """
    file = await session.get_one(FileObject, file_id)
    # Resolve the on-disk path before the row is deleted and expired.
    stored_path = settings.upload_dir / file.id
    await session.delete(file)
    await session.commit()
    # upload_file writes content to upload_dir / id; remove it so deleted files
    # do not leak disk space. missing_ok tolerates content that is already gone.
    stored_path.unlink(missing_ok=True)
    return FileDeleted(id=file_id, deleted=True, object="file")