from typing import Any, Iterator, Optional

from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_openai import ChatOpenAI
from openai import OpenAI

# LangChain message ``type`` values whose OpenAI chat role is spelled
# differently. Any other type ("system", "tool", ...) is sent through as-is.
_ROLE_BY_MESSAGE_TYPE = {"human": "user", "ai": "assistant"}


def _to_openai_messages(messages: Any) -> list:
    """Convert LangChain-style input into OpenAI chat-completions message dicts.

    Accepted shapes:
      * a plain ``str`` -> a single ``{"role": "user", ...}`` message;
      * an object exposing ``to_messages()`` (e.g. a prompt value produced by
        a prompt template in an LCEL chain) -> its messages are converted;
      * an iterable whose items are either already-formatted ``dict`` messages
        (passed through untouched) or message objects with ``type`` and
        ``content`` attributes.

    Only ``role`` and ``content`` are copied from message objects; fields such
    as ``tool_call_id`` are not forwarded.
    """
    if isinstance(messages, str):
        return [{"role": "user", "content": messages}]
    if hasattr(messages, "to_messages"):
        messages = messages.to_messages()

    converted = []
    for msg in messages:
        if isinstance(msg, dict):
            converted.append(msg)
            continue
        if msg.type == "chat":
            # Generic chat messages carry their role explicitly.
            role = getattr(msg, "role", msg.type)
        else:
            role = _ROLE_BY_MESSAGE_TYPE.get(msg.type, msg.type)
        converted.append({"role": role, "content": msg.content})
    return converted


class CustomChatOpenAI(ChatOpenAI):
    """ChatOpenAI subclass that keeps the provider's ``reasoning_content``.

    Requests go straight through the OpenAI SDK (not ChatOpenAI's own request
    path) so the non-standard ``reasoning_content`` field returned by some
    OpenAI-compatible providers is preserved. It is emitted ahead of the
    answer text, surrounded by the marker strings used below.
    """

    def _create_client(self) -> OpenAI:
        """Build a fresh OpenAI client from this model's base URL and API key."""
        api_key = self.openai_api_key
        return OpenAI(
            base_url=self.openai_api_base,
            api_key=api_key.get_secret_value() if api_key else None,
        )

    def _process_stream(
        self,
        messages,
        **kwargs,
    ) -> Iterator[AIMessageChunk]:
        """Stream a completion, yielding reasoning text first, then the answer.

        Args:
            messages: OpenAI-format message dicts.
            **kwargs: Extra parameters forwarded to ``chat.completions.create``.

        Yields:
            ``AIMessageChunk`` objects. An opening marker precedes the first
            reasoning delta. A closing marker precedes the first answer delta,
            or is emitted at end of stream if no answer arrives. Reasoning
            deltas that arrive after the reasoning section was closed are
            dropped.
        """
        # NOTE(review): the marker literals ("\n" to open, "\n\n" to close) look
        # like they may originally have contained think tags -- confirm intent.
        # The client and the HTTP stream are closed even if the consumer stops
        # iterating early (generator close triggers the finally / with-exit).
        with self._create_client() as client:
            stream = client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                stream=True,
                **kwargs,
            )
            try:
                in_think_tag = False
                think_closed = False
                for chunk in stream:
                    if not chunk.choices:
                        continue
                    delta = chunk.choices[0].delta
                    reasoning = getattr(delta, "reasoning_content", None)
                    content = getattr(delta, "content", None)

                    # Reasoning is accepted until the reasoning section has
                    # been closed by the first answer delta.
                    if reasoning and not think_closed:
                        if not in_think_tag:
                            yield AIMessageChunk(content="\n")
                            in_think_tag = True
                        yield AIMessageChunk(content=reasoning)

                    # Checked independently (not elif) so a delta carrying
                    # both fields does not lose its answer text.
                    if content:
                        if in_think_tag and not think_closed:
                            yield AIMessageChunk(content="\n\n")
                            in_think_tag = False
                            think_closed = True
                        yield AIMessageChunk(content=content)

                # Stream ended while still inside the reasoning section.
                if in_think_tag and not think_closed:
                    yield AIMessageChunk(content="\n\n")
            finally:
                stream.close()

    def stream(
        self,
        messages,
        config=None,
        **kwargs,
    ) -> Iterator[AIMessageChunk]:
        """Stream a response for ``messages``.

        Args:
            messages: A string, a prompt value, or a list of messages (see
                ``_to_openai_messages``).
            config: Accepted so LangChain runnables can pass it positionally;
                currently not used.
            **kwargs: Extra parameters forwarded to ``chat.completions.create``.
        """
        return self._process_stream(_to_openai_messages(messages), **kwargs)

    def _generate(
        self,
        messages,
        stop: Optional[list] = None,
        run_manager=None,
        **kwargs,
    ) -> ChatResult:
        """Run a non-streaming completion and return it as a ``ChatResult``.

        Any ``reasoning_content`` is prepended to the answer text, surrounded by
        the same kind of marker strings used in streaming mode.

        Args:
            messages: Messages in any shape accepted by ``_to_openai_messages``.
            stop: Optional stop sequences; forwarded only when given.
            run_manager: Accepted for ``BaseChatModel`` compatibility so it is
                not forwarded to the OpenAI SDK; currently not used.
            **kwargs: Extra parameters forwarded to ``chat.completions.create``.
        """
        if stop is not None:
            kwargs["stop"] = stop
        with self._create_client() as client:
            response = client.chat.completions.create(
                model=self.model_name,
                messages=_to_openai_messages(messages),
                stream=False,
                **kwargs,
            )

        reply = response.choices[0].message
        content = ""
        reasoning = getattr(reply, "reasoning_content", None)
        if reasoning:
            content += "\n\n" + reasoning + "\n\n\n"
        answer = getattr(reply, "content", None)
        if answer:
            content += answer

        message = AIMessage(content=content)
        return ChatResult(generations=[ChatGeneration(message=message)])

    def invoke(
        self,
        input,
        config=None,
        **kwargs,
    ):
        """Run the model on ``input``.

        Args:
            input: A string, a prompt value, or a list of messages.
            config: Accepted so LangChain runnables (e.g. ``prompt | llm``) can
                pass it positionally; currently not used.
            **kwargs: Extra parameters forwarded to ``chat.completions.create``.

        Returns:
            When ``self.streaming`` is true, an iterator of ``AIMessageChunk``.
            Otherwise a single ``AIMessage``.
        """
        if self.streaming:
            return self._process_stream(_to_openai_messages(input), **kwargs)
        result = self._generate(input, **kwargs)
        return result.generations[0].message