python-sdk

Session always closes after a few minutes.

Status: Open · johnson7788 opened this issue 8 months ago · 0 comments

Hi, I created two client classes to call the MCP server, following the official examples. I even explicitly set `sse_read_timeout=None` for the SSE connection. However, after a few minutes both clients fail with `Session closed: ClosedResourceError()`, and I have to restart them manually. Is there a way to keep their sessions open continuously?

pip freeze | grep mcp

mcp==1.9.0

class SSEMCPClient:
    """Client for an MCP server reached over Server-Sent Events (SSE).

    The SSE transport and the ``ClientSession`` are entered on one
    ``AsyncExitStack`` in :meth:`start` and closed together by
    :meth:`cleanup`. This lets a session outlive a single ``async with``
    block and be re-created when the server side closes it.

    NOTE(review): the mcp transports appear to be built on anyio task groups,
    which generally must be exited from the same task that entered them; keep
    ``start()`` and ``cleanup()`` on one task -- confirm against the mcp docs.

    Attributes:
        server_name: Label used in log messages.
        url: SSE endpoint URL passed to ``sse_client``.
        tools: ``Tool`` objects from the most recent successful ``list_tools``.
        session: The active ``ClientSession``, or ``None`` when not connected.
    """

    def __init__(self, server_name: str, url: str):
        self.server_name = server_name
        self.url = url
        self.tools = []
        self.session = None
        # Owns the transport and session contexts. aclose() exits both in
        # reverse order, and still exits the transport if the session's exit fails.
        self._exit_stack = AsyncExitStack()

    async def start(self):
        """Connect to the server and initialize an MCP session.

        Returns:
            True if the session was opened and initialized, False otherwise.
            On failure the error is logged, any partially opened transport or
            session is closed, and ``self.session`` is left as ``None``.
        """
        try:
            # sse_read_timeout=None is intended to disable the SSE read timeout
            # (per mcp's sse_client signature -- confirm for your mcp version).
            streams = await self._exit_stack.enter_async_context(
                sse_client(url=self.url, sse_read_timeout=None)
            )
            session = await self._exit_stack.enter_async_context(
                ClientSession(*streams)
            )
            await session.initialize()
            # Publish the session only once it is fully initialized.
            self.session = session
            return True
        except Exception as e:
            logger.error(f"Server {self.server_name}: SSE connection error: {str(e)}")
            # Don't leak a half-open transport/session.
            await self.cleanup()
            return False

    async def list_tools(self):
        """Fetch the server's tools and cache them on ``self.tools``.

        Returns:
            A list of ``Tool`` objects. Returns ``[]`` when not connected or
            when the request fails; the error is logged.
        """
        if not self.session:
            return []
        try:
            response = await self.session.list_tools()
            # Wrap each tool definition from the response in the local Tool type.
            self.tools = [
                Tool(tool.name, tool.description, tool.inputSchema)
                for tool in response.tools
            ]
            return self.tools
        except Exception as e:
            logger.error(f"Server {self.server_name}: List tools error: {str(e)}")
            return []

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict,
        retries: int = 2,
        delay: float = 1.0,
    ):
        """Call a tool on the server, reconnecting and retrying on failure.

        Args:
            tool_name: Name of the tool to invoke.
            arguments: Arguments passed to the tool.
            retries: Maximum number of attempts (values below 1 act as 1).
            delay: Seconds to wait between attempts.

        Returns:
            The tool result, converted with ``model_dump()`` when it has one.
            Returns ``{"error": "MCP SSE Not connected"}`` if no session exists
            on entry.

        Raises:
            ClosedResourceError: If the session is still closed on the last
                attempt. A reconnect is still attempted first, so the next call
                can reuse a live session.
            RuntimeError: If the session could not be re-established.
            Exception: Any other error raised by the last attempt.
        """
        if not self.session:
            return {"error": "MCP SSE Not connected"}

        max_attempts = max(1, retries)
        for attempt in range(1, max_attempts + 1):
            try:
                if self.session is None:
                    # An earlier reconnect failed; try again before calling.
                    if not await self.start():
                        raise RuntimeError(
                            f"Server {self.server_name} could not be reconnected"
                        )
                logger.info(f"开始使用SSE MCP协议调用工具,tool_name: {tool_name}, arguments: {arguments}")
                response = await self.session.call_tool(tool_name, arguments)
                # Convert pydantic results to plain dicts when possible.
                return response.model_dump() if hasattr(response, 'model_dump') else response
            except ClosedResourceError as e:
                logger.warning(f"Session closed: {e.__repr__()}, attempting to restart session.")
                # Re-establish the session (even on the final attempt, so the
                # client is left usable for the next call).
                await self.cleanup()
                if await self.start():
                    logger.info("Session restarted successfully.")
                else:
                    logger.error("Failed to restart session.")
                if attempt >= max_attempts:
                    logger.error("Max retries reached. Failing.")
                    raise
                await asyncio.sleep(delay)
            except Exception as e:
                logger.warning(
                    f"Error executing tool: {e.__repr__()}. Attempt {attempt} of {retries}."
                )
                if attempt >= max_attempts:
                    logger.error("Max retries reached. Failing.")
                    raise
                logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)

    async def cleanup(self):
        """Close the session and SSE transport; safe to call repeatedly.

        Errors while closing are logged rather than raised. ``self.session``
        is always reset to ``None`` so a closed session is never reused.
        """
        try:
            await self._exit_stack.aclose()
        except Exception as e:
            logger.warning(f"Error cleaning up SSE client: {e.__repr__()}")
        finally:
            self.session = None


class MCPClient:
    """Manages a stdio MCP server connection and tool execution.

    The stdio transport and ``ClientSession`` are entered on ``exit_stack``
    in :meth:`start` and closed together by :meth:`cleanup`.
    """

    def __init__(self, server_name: str, command, args=None, env=None) -> None:
        self.name: str = server_name
        self.config: dict[str, Any] = {"command": command, "args": args, "env": env}
        self.stdio_context: Any | None = None
        self.session: ClientSession | None = None
        # Serializes cleanup() so concurrent callers don't close the stack twice at once.
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.exit_stack: AsyncExitStack = AsyncExitStack()
        self.tools = []

    async def start(self) -> bool:
        """Launch the server and initialize an MCP session over stdio.

        A command of ``"npx"`` is resolved to a full path with ``shutil.which``;
        other commands are used as given. If ``env`` was supplied, it is merged
        over the current process environment.

        Returns:
            True on success. False if building the parameters, connecting, or
            initializing failed; the error is logged and partially opened
            resources are released.

        Raises:
            ValueError: If the resolved command is ``None`` (e.g. ``npx`` is
                not on PATH).
        """
        command = (
            shutil.which("npx")
            if self.config["command"] == "npx"
            else self.config["command"]
        )
        if command is None:
            raise ValueError("The command must be a valid string and cannot be None.")

        try:
            server_params = StdioServerParameters(
                command=command,
                # Don't pass None through; the parameters presumably expect a list.
                args=self.config["args"] or [],
                env={**os.environ, **self.config["env"]}
                if self.config.get("env")
                else None,
            )
            read, write = await self.exit_stack.enter_async_context(
                stdio_client(server_params)
            )
            session = await self.exit_stack.enter_async_context(
                ClientSession(read, write)
            )
            await session.initialize()
            # Publish the session only once it is fully initialized.
            self.session = session
            return True
        except Exception as e:
            logger.error(f"Error initializing server {self.name}: {e}")
            await self.cleanup()
            return False

    async def list_tools(self) -> list[Any]:
        """List available tools from the server and cache them on ``self.tools``.

        Returns:
            A list of available tools.

        Raises:
            RuntimeError: If the server is not initialized.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")

        tools_response = await self.session.list_tools()
        tools = []

        # Iterating the response yields (field_name, value) pairs (presumably a
        # pydantic model); the "tools" field holds the tool definitions.
        for item in tools_response:
            if isinstance(item, tuple) and item[0] == "tools":
                tools.extend(
                    Tool(tool.name, tool.description, tool.inputSchema)
                    for tool in item[1]
                )
        self.tools = tools
        return tools

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> Any:
        """Execute a tool with a retry mechanism.

        Args:
            tool_name: Name of the tool to execute.
            arguments: Tool arguments.
            retries: Maximum number of attempts (values below 1 act as 1).
            delay: Delay between retries in seconds.

        Returns:
            Tool execution result, converted with ``model_dump()`` when it has one.

        Raises:
            RuntimeError: If the server is not initialized, or the session could
                not be re-established after it was closed.
            ClosedResourceError: If the session is still closed on the last
                attempt. A reconnect is still attempted first, so the next call
                can reuse a live session.
            Exception: If tool execution fails after all retries.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")

        max_attempts = max(1, retries)
        for attempt in range(1, max_attempts + 1):
            try:
                if self.session is None:
                    # An earlier restart failed; try again instead of calling
                    # a method on None.
                    if not await self.start():
                        raise RuntimeError(
                            f"Server {self.name} could not be reconnected"
                        )
                logger.info(f"Executing {tool_name}...")
                response = await self.session.call_tool(tool_name, arguments)
                return response.model_dump() if hasattr(response, 'model_dump') else response
            except ClosedResourceError as e:
                logger.warning(f"Session closed: {e.__repr__()}, attempting to restart session.")
                # Re-establish the session (even on the final attempt, so the
                # client is left usable for the next call).
                await self.cleanup()
                if await self.start():
                    logger.info("Session restarted successfully.")
                else:
                    logger.error("Failed to restart session.")
                if attempt >= max_attempts:
                    logger.error("Max retries reached. Failing.")
                    raise
                await asyncio.sleep(delay)
            except Exception as e:
                logger.warning(
                    f"Error executing tool: {e}. Attempt {attempt} of {retries}."
                )
                if attempt >= max_attempts:
                    logger.error("Max retries reached. Failing.")
                    raise
                logger.info(f"Retrying in {delay} seconds...")
                await asyncio.sleep(delay)

    async def cleanup(self) -> None:
        """Clean up server resources; errors are logged, not raised."""
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                logger.error(f"Error during cleanup of server {self.name}: {e}")
            finally:
                # Always drop references so a closed session is never reused.
                self.session = None
                self.stdio_context = None

— johnson7788, May 23 '25 23:05