Coverage for src\agents_sdk_models\llm.py: 93%
101 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-04 17:38 +0900
« prev ^ index » next coverage.py v7.8.0, created at 2025-06-04 17:38 +0900
1from typing import Literal, Optional, Any, List
2from agents.models.interface import Model
3from agents import OpenAIChatCompletionsModel # Import from agents library
4from agents import set_tracing_disabled # Import set_tracing_disabled for tracing control
5# English: Import OpenAI client
6# 日本語: OpenAI クライアントをインポート
7from openai import AsyncOpenAI
8from agents.models.openai_responses import OpenAIResponsesModel
9# English: Import HTTP client for API requests
10# 日本語: API リクエスト用の HTTP クライアントをインポート
11import httpx
12import asyncio
14from .anthropic import ClaudeModel
15from .gemini import GeminiModel
16from .ollama import OllamaModel
18# Define the provider type hint
19ProviderType = Literal["openai", "google", "anthropic", "ollama"]
21import os
23def get_llm(
24 model: Optional[str] = None,
25 provider: Optional[ProviderType] = None,
26 temperature: float = 0.3,
27 api_key: Optional[str] = None,
28 base_url: Optional[str] = None,
29 thinking: bool = False,
30 **kwargs: Any,
31) -> Model:
32 """
33 Factory function to get an instance of a language model based on the provider.
35 English:
36 Factory function to get an instance of a language model based on the provider.
38 日本語:
39 プロバイダーに基づいて言語モデルのインスタンスを取得するファクトリ関数。
41 Args:
42 provider (ProviderType): The LLM provider ("openai", "google", "anthropic", "ollama"). Defaults to "openai".
43 LLM プロバイダー ("openai", "google", "anthropic", "ollama")。デフォルトは "openai"。
44 model (Optional[str]): The specific model name for the provider. If None, uses the default for the provider.
45 プロバイダー固有のモデル名。None の場合、プロバイダーのデフォルトを使用します。
46 temperature (float): Sampling temperature. Defaults to 0.3.
47 サンプリング温度。デフォルトは 0.3。
48 api_key (Optional[str]): API key for the provider, if required.
49 プロバイダーの API キー (必要な場合)。
50 base_url (Optional[str]): Base URL for the provider's API, if needed (e.g., for self-hosted Ollama or OpenAI-compatible APIs).
51 プロバイダー API のベース URL (必要な場合、例: セルフホストの Ollama や OpenAI 互換 API)。
52 thinking (bool): Enable thinking mode for Claude models. Defaults to False.
53 Claude モデルの思考モードを有効にするか。デフォルトは False。
54 tracing (bool): Whether to enable tracing for the Agents SDK. Defaults to False.
55 Agents SDK のトレーシングを有効化するか。デフォルトは False。
56 **kwargs (Any): Additional keyword arguments to pass to the model constructor.
57 モデルのコンストラクタに渡す追加のキーワード引数。
59 Returns:
60 Model: An instance of the appropriate language model class.
61 適切な言語モデルクラスのインスタンス。
63 Raises:
64 ValueError: If an unsupported provider is specified.
65 サポートされていないプロバイダーが指定された場合。
66 """
67 # English: Configure OpenAI Agents SDK tracing
68 # 日本語: OpenAI Agents SDK のトレーシングを設定する
69 # set_tracing_disabled(not tracing)
72 if model is None:
73 model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
75 def get_provider_canditate(model: str) -> ProviderType:
76 if "gpt" in model:
77 return "openai"
78 if "o3" in model or "o4" in model:
79 return "openai"
80 elif "gemini" in model:
81 return "google"
82 elif "claude" in model:
83 return "anthropic"
84 else:
85 return "ollama"
87 if provider is None:
88 provider = get_provider_canditate(model)
90 if provider == "openai":
91 # Use the standard OpenAI model from the agents library
92 # agentsライブラリの標準 OpenAI モデルを使用
93 openai_kwargs = kwargs.copy()
95 # English: Prepare arguments for OpenAI client and model
96 # 日本語: OpenAI クライアントとモデルの引数を準備
97 client_args = {}
98 model_args = {}
100 # English: Set API key for client
101 # 日本語: クライアントに API キーを設定
102 if api_key:
103 client_args['api_key'] = api_key
104 # English: Set base URL for client
105 # 日本語: クライアントにベース URL を設定
106 if base_url:
107 client_args['base_url'] = base_url
109 # English: Set model name for model constructor
110 # 日本語: モデルコンストラクタにモデル名を設定
111 model_args['model'] = model if model else "gpt-4o-mini" # Default to gpt-4o-mini
113 # English: Temperature is likely handled by the runner or set post-init,
114 # English: so remove it from constructor args.
115 # 日本語: temperature はランナーによって処理されるか、初期化後に設定される可能性が高いため、
116 # 日本語: コンストラクタ引数から削除します。
117 # model_args['temperature'] = temperature # Removed based on TypeError
119 # English: Add any other relevant kwargs passed in, EXCLUDING temperature
120 # 日本語: 渡された他の関連する kwargs を追加 (temperature を除く)
121 # Example: max_tokens, etc. Filter out args meant for the client.
122 # 例: max_tokens など。クライアント向けの引数を除外します。
123 for key, value in kwargs.items():
124 # English: Exclude client args, thinking, temperature, and tracing
125 # 日本語: クライアント引数、thinking、temperature、tracing を除外
126 if key not in ['api_key', 'base_url', 'thinking', 'temperature', 'tracing']:
127 model_args[key] = value
129 # English: Remove 'thinking' as it's not used by OpenAI model
130 # 日本語: OpenAI モデルでは使用されないため 'thinking' を削除
131 model_args.pop('thinking', None)
133 # English: Instantiate the OpenAI client
134 # 日本語: OpenAI クライアントをインスタンス化
135 openai_client = AsyncOpenAI(**client_args)
137 # English: Instantiate and return the model, passing the client and model args
138 # 日本語: クライアントとモデル引数を渡してモデルをインスタンス化して返す
139 return OpenAIResponsesModel(
140 openai_client=openai_client,
141 **model_args
142 )
143 elif provider == "google":
144 gemini_kwargs = kwargs.copy()
145 if model:
146 gemini_kwargs['model'] = model
147 # thinking is not used by GeminiModel
148 gemini_kwargs.pop('thinking', None)
149 return GeminiModel(
150 temperature=temperature,
151 api_key=api_key,
152 base_url=base_url, # Although Gemini doesn't typically use base_url, pass it if provided
153 **gemini_kwargs
154 )
155 elif provider == "anthropic":
156 claude_kwargs = kwargs.copy()
157 if model:
158 claude_kwargs['model'] = model
159 return ClaudeModel(
160 temperature=temperature,
161 api_key=api_key,
162 base_url=base_url, # Although Claude doesn't typically use base_url, pass it if provided
163 thinking=thinking,
164 **claude_kwargs
165 )
166 elif provider == "ollama":
167 ollama_kwargs = kwargs.copy()
168 if model:
169 ollama_kwargs['model'] = model
170 # thinking is not used by OllamaModel
171 ollama_kwargs.pop('thinking', None)
172 return OllamaModel(
173 temperature=temperature,
174 base_url=base_url,
175 api_key=api_key, # Although Ollama doesn't typically use api_key, pass it if provided
176 **ollama_kwargs
177 )
178 else:
179 raise ValueError(f"Unsupported provider: {provider}. Must be one of {ProviderType.__args__}")
181async def get_available_models_async(
182 providers: List[ProviderType],
183 ollama_base_url: Optional[str] = None
184) -> dict[str, List[str]]:
185 """
186 Get available model names for specified providers.
188 English:
189 Get available model names for specified providers.
191 日本語:
192 指定されたプロバイダーの利用可能なモデル名を取得します。
194 Args:
195 providers (List[ProviderType]): List of providers to get models for.
196 モデルを取得するプロバイダーのリスト。
197 ollama_base_url (Optional[str]): Base URL for Ollama API. If None, uses environment variable or default.
198 Ollama API のベース URL。None の場合、環境変数またはデフォルトを使用。
200 Returns:
201 dict[str, List[str]]: Dictionary mapping provider names to lists of available models.
202 プロバイダー名と利用可能なモデルのリストのマッピング辞書。
204 Raises:
205 ValueError: If an unsupported provider is specified.
206 サポートされていないプロバイダーが指定された場合。
207 httpx.RequestError: If there's an error connecting to the Ollama API.
208 Ollama API への接続エラーが発生した場合。
209 """
210 result = {}
212 for provider in providers:
213 if provider == "openai":
214 # English: OpenAI models - latest available models
215 # 日本語: OpenAI モデル - 最新の利用可能なモデル
216 result["openai"] = [
217 "gpt-4o",
218 "gpt-4o-mini",
219 "gpt-4.1",
220 "o3",
221 "o4-mini"
222 ]
223 elif provider == "google":
224 # English: Google Gemini models - latest 2.5 series models
225 # 日本語: Google Gemini モデル - 最新の 2.5 シリーズモデル
226 result["google"] = [
227 "gemini-2.5-pro",
228 "gemini-2.5-flash"
229 ]
230 elif provider == "anthropic":
231 # English: Anthropic Claude models - latest Claude-4 series models
232 # 日本語: Anthropic Claude モデル - 最新の Claude-4 シリーズモデル
233 result["anthropic"] = [
234 "claude-opus-4",
235 "claude-sonnet-4"
236 ]
237 elif provider == "ollama":
238 # English: Get Ollama base URL from parameter, environment variable, or default
239 # 日本語: パラメータ、環境変数、またはデフォルトから Ollama ベース URL を取得
240 if ollama_base_url is None:
241 ollama_base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
243 try:
244 # English: Fetch available models from Ollama API
245 # 日本語: Ollama API から利用可能なモデルを取得
246 async with httpx.AsyncClient() as client:
247 response = await client.get(f"{ollama_base_url}/api/tags")
248 response.raise_for_status()
250 # English: Parse the response to extract model names
251 # 日本語: レスポンスを解析してモデル名を抽出
252 data = response.json()
253 models = []
254 if "models" in data:
255 for model_info in data["models"]:
256 if "name" in model_info:
257 models.append(model_info["name"])
259 result["ollama"] = models
261 except httpx.RequestError as e:
262 # English: If connection fails, return empty list with error info
263 # 日本語: 接続に失敗した場合、エラー情報と共に空のリストを返す
264 result["ollama"] = []
265 print(f"Warning: Could not connect to Ollama at {ollama_base_url}: {e}")
266 except Exception as e:
267 # English: Handle other errors
268 # 日本語: その他のエラーを処理
269 result["ollama"] = []
270 print(f"Warning: Error fetching Ollama models: {e}")
271 else:
272 raise ValueError(f"Unsupported provider: {provider}. Must be one of {ProviderType.__args__}")
274 return result
276def get_available_models(
277 providers: List[ProviderType],
278 ollama_base_url: Optional[str] = None
279) -> dict[str, List[str]]:
280 """
281 Get available model names for specified providers (synchronous version).
283 English:
284 Get available model names for specified providers (synchronous version).
286 日本語:
287 指定されたプロバイダーの利用可能なモデル名を取得します(同期版)。
289 Args:
290 providers (List[ProviderType]): List of providers to get models for.
291 モデルを取得するプロバイダーのリスト。
292 ollama_base_url (Optional[str]): Base URL for Ollama API. If None, uses environment variable or default.
293 Ollama API のベース URL。None の場合、環境変数またはデフォルトを使用。
295 Returns:
296 dict[str, List[str]]: Dictionary mapping provider names to lists of available models.
297 プロバイダー名と利用可能なモデルのリストのマッピング辞書。
298 """
299 try:
300 # English: Try to get the current event loop
301 # 日本語: 現在のイベントループを取得しようとする
302 loop = asyncio.get_running_loop()
303 # English: If we're in a running loop, we need to handle this differently
304 # 日本語: 実行中のループ内にいる場合、異なる方法で処理する必要がある
305 import concurrent.futures
306 with concurrent.futures.ThreadPoolExecutor() as executor:
307 future = executor.submit(asyncio.run, get_available_models_async(providers, ollama_base_url))
308 return future.result()
309 except RuntimeError:
310 # English: No running event loop, safe to use asyncio.run()
311 # 日本語: 実行中のイベントループがない場合、asyncio.run() を安全に使用
312 return asyncio.run(get_available_models_async(providers, ollama_base_url))