Coverage for src/refinire/core/llm.py: 93%

101 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-06-15 18:51 +0900

1from typing import Literal, Optional, Any, List 

2from agents.models.interface import Model 

3from agents import OpenAIChatCompletionsModel # Import from agents library 

4from agents import set_tracing_disabled # Import set_tracing_disabled for tracing control 

5# English: Import OpenAI client 

6# 日本語: OpenAI クライアントをインポート 

7from openai import AsyncOpenAI 

8from agents.models.openai_responses import OpenAIResponsesModel 

9# English: Import HTTP client for API requests 

10# 日本語: API リクエスト用の HTTP クライアントをインポート 

11import httpx 

12import asyncio 

13 

14from .anthropic import ClaudeModel 

15from .gemini import GeminiModel 

16from .ollama import OllamaModel 

17 

18# Define the provider type hint 

19ProviderType = Literal["openai", "google", "anthropic", "ollama"] 

20 

21import os 

22 

23def get_llm( 

24 model: Optional[str] = None, 

25 provider: Optional[ProviderType] = None, 

26 temperature: float = 0.3, 

27 api_key: Optional[str] = None, 

28 base_url: Optional[str] = None, 

29 thinking: bool = False, 

30 **kwargs: Any, 

31) -> Model: 

32 """ 

33 Factory function to get an instance of a language model based on the provider. 

34 

35 English: 

36 Factory function to get an instance of a language model based on the provider. 

37 

38 日本語: 

39 プロバイダーに基づいて言語モデルのインスタンスを取得するファクトリ関数。 

40 

41 Args: 

42 provider (ProviderType): The LLM provider ("openai", "google", "anthropic", "ollama"). Defaults to "openai". 

43 LLM プロバイダー ("openai", "google", "anthropic", "ollama")。デフォルトは "openai"。 

44 model (Optional[str]): The specific model name for the provider. If None, uses the default for the provider. 

45 プロバイダー固有のモデル名。None の場合、プロバイダーのデフォルトを使用します。 

46 temperature (float): Sampling temperature. Defaults to 0.3. 

47 サンプリング温度。デフォルトは 0.3。 

48 api_key (Optional[str]): API key for the provider, if required. 

49 プロバイダーの API キー (必要な場合)。 

50 base_url (Optional[str]): Base URL for the provider's API, if needed (e.g., for self-hosted Ollama or OpenAI-compatible APIs). 

51 プロバイダー API のベース URL (必要な場合、例: セルフホストの Ollama や OpenAI 互換 API)。 

52 thinking (bool): Enable thinking mode for Claude models. Defaults to False. 

53 Claude モデルの思考モードを有効にするか。デフォルトは False。 

54 tracing (bool): Whether to enable tracing for the Agents SDK. Defaults to False. 

55 Agents SDK のトレーシングを有効化するか。デフォルトは False。 

56 **kwargs (Any): Additional keyword arguments to pass to the model constructor. 

57 モデルのコンストラクタに渡す追加のキーワード引数。 

58 

59 Returns: 

60 Model: An instance of the appropriate language model class. 

61 適切な言語モデルクラスのインスタンス。 

62 

63 Raises: 

64 ValueError: If an unsupported provider is specified. 

65 サポートされていないプロバイダーが指定された場合。 

66 """ 

67 # English: Configure OpenAI Agents SDK tracing 

68 # 日本語: OpenAI Agents SDK のトレーシングを設定する 

69 # set_tracing_disabled(not tracing) 

70 

71 

72 if model is None: 

73 model = os.environ.get("REFINIRE_DEFAULT_LLM_MODEL", "gpt-4o-mini") 

74 

75 def get_provider_canditate(model: str) -> ProviderType: 

76 if "gpt" in model: 

77 return "openai" 

78 if "o3" in model or "o4" in model: 

79 return "openai" 

80 elif "gemini" in model: 

81 return "google" 

82 elif "claude" in model: 

83 return "anthropic" 

84 else: 

85 return "ollama" 

86 

87 if provider is None: 

88 provider = get_provider_canditate(model) 

89 

90 if provider == "openai": 

91 # Use the standard OpenAI model from the agents library 

92 # agentsライブラリの標準 OpenAI モデルを使用 

93 openai_kwargs = kwargs.copy() 

94 

95 # English: Prepare arguments for OpenAI client and model 

96 # 日本語: OpenAI クライアントとモデルの引数を準備 

97 client_args = {} 

98 model_args = {} 

99 

100 # English: Set API key for client 

101 # 日本語: クライアントに API キーを設定 

102 if api_key: 

103 client_args['api_key'] = api_key 

104 # English: Set base URL for client 

105 # 日本語: クライアントにベース URL を設定 

106 if base_url: 

107 client_args['base_url'] = base_url 

108 

109 # English: Set model name for model constructor 

110 # 日本語: モデルコンストラクタにモデル名を設定 

111 model_args['model'] = model if model else "gpt-4o-mini" # Default to gpt-4o-mini 

112 

113 # English: Temperature is likely handled by the runner or set post-init, 

114 # English: so remove it from constructor args. 

115 # 日本語: temperature はランナーによって処理されるか、初期化後に設定される可能性が高いため、 

116 # 日本語: コンストラクタ引数から削除します。 

117 # model_args['temperature'] = temperature # Removed based on TypeError 

118 

119 # English: Add any other relevant kwargs passed in, EXCLUDING temperature 

120 # 日本語: 渡された他の関連する kwargs を追加 (temperature を除く) 

121 # Example: max_tokens, etc. Filter out args meant for the client. 

122 # 例: max_tokens など。クライアント向けの引数を除外します。 

123 for key, value in kwargs.items(): 

124 # English: Exclude client args, thinking, temperature, and tracing 

125 # 日本語: クライアント引数、thinking、temperature、tracing を除外 

126 if key not in ['api_key', 'base_url', 'thinking', 'temperature', 'tracing']: 

127 model_args[key] = value 

128 

129 # English: Remove 'thinking' as it's not used by OpenAI model 

130 # 日本語: OpenAI モデルでは使用されないため 'thinking' を削除 

131 model_args.pop('thinking', None) 

132 

133 # English: Instantiate the OpenAI client 

134 # 日本語: OpenAI クライアントをインスタンス化 

135 openai_client = AsyncOpenAI(**client_args) 

136 

137 # English: Instantiate and return the model, passing the client and model args 

138 # 日本語: クライアントとモデル引数を渡してモデルをインスタンス化して返す 

139 return OpenAIResponsesModel( 

140 openai_client=openai_client, 

141 **model_args 

142 ) 

143 elif provider == "google": 

144 gemini_kwargs = kwargs.copy() 

145 if model: 

146 gemini_kwargs['model'] = model 

147 # thinking is not used by GeminiModel 

148 gemini_kwargs.pop('thinking', None) 

149 return GeminiModel( 

150 temperature=temperature, 

151 api_key=api_key, 

152 base_url=base_url, # Although Gemini doesn't typically use base_url, pass it if provided 

153 **gemini_kwargs 

154 ) 

155 elif provider == "anthropic": 

156 claude_kwargs = kwargs.copy() 

157 if model: 

158 claude_kwargs['model'] = model 

159 return ClaudeModel( 

160 temperature=temperature, 

161 api_key=api_key, 

162 base_url=base_url, # Although Claude doesn't typically use base_url, pass it if provided 

163 thinking=thinking, 

164 **claude_kwargs 

165 ) 

166 elif provider == "ollama": 

167 ollama_kwargs = kwargs.copy() 

168 if model: 

169 ollama_kwargs['model'] = model 

170 # thinking is not used by OllamaModel 

171 ollama_kwargs.pop('thinking', None) 

172 return OllamaModel( 

173 temperature=temperature, 

174 base_url=base_url, 

175 api_key=api_key, # Although Ollama doesn't typically use api_key, pass it if provided 

176 **ollama_kwargs 

177 ) 

178 else: 

179 raise ValueError(f"Unsupported provider: {provider}. Must be one of {ProviderType.__args__}") 

180 

181async def get_available_models_async( 

182 providers: List[ProviderType], 

183 ollama_base_url: Optional[str] = None 

184) -> dict[str, List[str]]: 

185 """ 

186 Get available model names for specified providers. 

187  

188 English: 

189 Get available model names for specified providers. 

190  

191 日本語: 

192 指定されたプロバイダーの利用可能なモデル名を取得します。 

193  

194 Args: 

195 providers (List[ProviderType]): List of providers to get models for. 

196 モデルを取得するプロバイダーのリスト。 

197 ollama_base_url (Optional[str]): Base URL for Ollama API. If None, uses environment variable or default. 

198 Ollama API のベース URL。None の場合、環境変数またはデフォルトを使用。 

199  

200 Returns: 

201 dict[str, List[str]]: Dictionary mapping provider names to lists of available models. 

202 プロバイダー名と利用可能なモデルのリストのマッピング辞書。 

203  

204 Raises: 

205 ValueError: If an unsupported provider is specified. 

206 サポートされていないプロバイダーが指定された場合。 

207 httpx.RequestError: If there's an error connecting to the Ollama API. 

208 Ollama API への接続エラーが発生した場合。 

209 """ 

210 result = {} 

211 

212 for provider in providers: 

213 if provider == "openai": 

214 # English: OpenAI models - latest available models 

215 # 日本語: OpenAI モデル - 最新の利用可能なモデル 

216 result["openai"] = [ 

217 "gpt-4o", 

218 "gpt-4o-mini", 

219 "gpt-4.1", 

220 "o3", 

221 "o4-mini" 

222 ] 

223 elif provider == "google": 

224 # English: Google Gemini models - latest 2.5 series models 

225 # 日本語: Google Gemini モデル - 最新の 2.5 シリーズモデル 

226 result["google"] = [ 

227 "gemini-2.5-pro", 

228 "gemini-2.5-flash" 

229 ] 

230 elif provider == "anthropic": 

231 # English: Anthropic Claude models - latest Claude-4 series models 

232 # 日本語: Anthropic Claude モデル - 最新の Claude-4 シリーズモデル 

233 result["anthropic"] = [ 

234 "claude-opus-4", 

235 "claude-sonnet-4" 

236 ] 

237 elif provider == "ollama": 

238 # English: Get Ollama base URL from parameter, environment variable, or default 

239 # 日本語: パラメータ、環境変数、またはデフォルトから Ollama ベース URL を取得 

240 if ollama_base_url is None: 

241 ollama_base_url = os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434") 

242 

243 try: 

244 # English: Fetch available models from Ollama API 

245 # 日本語: Ollama API から利用可能なモデルを取得 

246 async with httpx.AsyncClient() as client: 

247 response = await client.get(f"{ollama_base_url}/api/tags") 

248 response.raise_for_status() 

249 

250 # English: Parse the response to extract model names 

251 # 日本語: レスポンスを解析してモデル名を抽出 

252 data = response.json() 

253 models = [] 

254 if "models" in data: 

255 for model_info in data["models"]: 

256 if "name" in model_info: 

257 models.append(model_info["name"]) 

258 

259 result["ollama"] = models 

260 

261 except httpx.RequestError as e: 

262 # English: If connection fails, return empty list with error info 

263 # 日本語: 接続に失敗した場合、エラー情報と共に空のリストを返す 

264 result["ollama"] = [] 

265 print(f"Warning: Could not connect to Ollama at {ollama_base_url}: {e}") 

266 except Exception as e: 

267 # English: Handle other errors 

268 # 日本語: その他のエラーを処理 

269 result["ollama"] = [] 

270 print(f"Warning: Error fetching Ollama models: {e}") 

271 else: 

272 raise ValueError(f"Unsupported provider: {provider}. Must be one of {ProviderType.__args__}") 

273 

274 return result 

275 

276def get_available_models( 

277 providers: List[ProviderType], 

278 ollama_base_url: Optional[str] = None 

279) -> dict[str, List[str]]: 

280 """ 

281 Get available model names for specified providers (synchronous version). 

282  

283 English: 

284 Get available model names for specified providers (synchronous version). 

285  

286 日本語: 

287 指定されたプロバイダーの利用可能なモデル名を取得します(同期版)。 

288  

289 Args: 

290 providers (List[ProviderType]): List of providers to get models for. 

291 モデルを取得するプロバイダーのリスト。 

292 ollama_base_url (Optional[str]): Base URL for Ollama API. If None, uses environment variable or default. 

293 Ollama API のベース URL。None の場合、環境変数またはデフォルトを使用。 

294  

295 Returns: 

296 dict[str, List[str]]: Dictionary mapping provider names to lists of available models. 

297 プロバイダー名と利用可能なモデルのリストのマッピング辞書。 

298 """ 

299 try: 

300 # English: Try to get the current event loop 

301 # 日本語: 現在のイベントループを取得しようとする 

302 loop = asyncio.get_running_loop() 

303 # English: If we're in a running loop, we need to handle this differently 

304 # 日本語: 実行中のループ内にいる場合、異なる方法で処理する必要がある 

305 import concurrent.futures 

306 with concurrent.futures.ThreadPoolExecutor() as executor: 

307 future = executor.submit(asyncio.run, get_available_models_async(providers, ollama_base_url)) 

308 return future.result() 

309 except RuntimeError: 

310 # English: No running event loop, safe to use asyncio.run() 

311 # 日本語: 実行中のイベントループがない場合、asyncio.run() を安全に使用 

312 return asyncio.run(get_available_models_async(providers, ollama_base_url))