# Copyright 2026 Google Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import asyncio from colab_mcp import session from fastmcp.server.middleware import MiddlewareContext import pytest from unittest.mock import patch, AsyncMock, Mock @pytest.fixture(autouse=True) def mock_webbrowser(monkeypatch): import webbrowser mock_open = Mock() monkeypatch.setattr(webbrowser, "open_new", mock_open) return mock_open @pytest.fixture def mock_wss(): """Provides a mock ColabWebSocketServer instance.""" return MockColabWebSocketServer() class MockColabWebSocketServer: def __init__(self): self.connection_live = asyncio.Event() self.read_stream = AsyncMock() self.write_stream = AsyncMock() self.token = "test-token" self.port = 1234 async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): pass @pytest.fixture def mock_proxy_client(mock_wss): client = Mock(spec=session.ColabProxyClient) client.wss = mock_wss client.is_connected.return_value = False return client class TestColabProxyMiddleware: @pytest.mark.asyncio async def test_connection_live(self, mock_proxy_client): """Tests connection state change from disconnected to connected.""" middleware = session.ColabProxyMiddleware(mock_proxy_client) mock_proxy_client.is_connected.return_value = True context = Mock(spec=MiddlewareContext) context.fastmcp_context.set_state = Mock() context.fastmcp_context.send_tool_list_changed = AsyncMock() call_next = AsyncMock() await middleware.on_message(context, call_next) call_next.assert_called_once_with(context) context.fastmcp_context.set_state.assert_any_call("fe_connected", True) context.fastmcp_context.set_state.assert_any_call("proxy_token", "test-token") context.fastmcp_context.set_state.assert_any_call("proxy_port", 1234) assert middleware.last_message_connected is True context.fastmcp_context.send_tool_list_changed.assert_called_once() @pytest.mark.asyncio async def test_connection_not_live(self, mock_proxy_client): """Tests connection state change from connected to disconnected.""" mock_proxy_client.is_connected.return_value = True middleware = session.ColabProxyMiddleware(mock_proxy_client) mock_proxy_client.is_connected.return_value = False context = Mock(spec=MiddlewareContext) context.fastmcp_context.set_state = Mock() context.fastmcp_context.send_tool_list_changed = AsyncMock() call_next = AsyncMock() await middleware.on_message(context, call_next) call_next.assert_called_once_with(context) context.fastmcp_context.set_state.assert_any_call("fe_connected", False) assert middleware.last_message_connected is False context.fastmcp_context.send_tool_list_changed.assert_called_once() @pytest.mark.asyncio async def test_no_connection_change(self, mock_proxy_client): """Tests no connection state change.""" mock_proxy_client.is_connected.return_value = True middleware = session.ColabProxyMiddleware(mock_proxy_client) context = Mock(spec=MiddlewareContext) context.fastmcp_context.set_state = Mock() context.fastmcp_context.send_tool_list_changed = AsyncMock() call_next = AsyncMock() await middleware.on_message(context, call_next) call_next.assert_called_once_with(context) context.fastmcp_context.set_state.assert_any_call("fe_connected", True) assert middleware.last_message_connected is True context.fastmcp_context.send_tool_list_changed.assert_not_called() @pytest.mark.asyncio async def test_on_call_tool_await_connection(self, mock_proxy_client): middleware = session.ColabProxyMiddleware(mock_proxy_client) context = Mock() context.fastmcp_context.report_progress = AsyncMock() context.message.name = session.INJECTED_TOOL_NAME mock_proxy_client.is_connected.side_effect = [False, True] mock_proxy_client.await_proxy_connection = AsyncMock() call_next = AsyncMock() result = await middleware.on_call_tool(context, call_next) mock_proxy_client.await_proxy_connection.assert_called_once() context.fastmcp_context.report_progress.assert_called() assert result.structured_content == {"result": True} @pytest.mark.asyncio async def test_on_call_tool_timeout(self, mock_proxy_client): middleware = session.ColabProxyMiddleware(mock_proxy_client) context = Mock() context.fastmcp_context.report_progress = AsyncMock() context.message.name = session.INJECTED_TOOL_NAME mock_proxy_client.is_connected.return_value = False mock_proxy_client.await_proxy_connection = AsyncMock() call_next = AsyncMock() result = await middleware.on_call_tool(context, call_next) mock_proxy_client.await_proxy_connection.assert_called_once() assert result.structured_content == {"result": False} class TestCheckSessionProxyToolFn: @pytest.mark.asyncio async def test_connected(self): ctx = Mock() ctx.get_state.side_effect = ( lambda k: True if k == session.FE_CONNECTED_KEY else None ) assert await session.check_session_proxy_tool_fn(ctx) is True @pytest.mark.asyncio async def test_disconnected(self, mock_webbrowser): ctx = Mock() def get_state(k): if k == session.FE_CONNECTED_KEY: return False if k == session.PROXY_TOKEN_KEY: return "test-token" if k == session.PROXY_PORT_KEY: return 1234 return None ctx.get_state.side_effect = get_state assert await session.check_session_proxy_tool_fn(ctx) is False mock_webbrowser.assert_called_once() args, _ = mock_webbrowser.call_args assert "mcpProxyToken=test-token" in args[0] assert "mcpProxyPort=1234" in args[0] class TestColabProxyClient: def test_is_connected(self, mock_wss): client = session.ColabProxyClient(mock_wss) assert client.is_connected() is False mock_wss.connection_live.set() assert client.is_connected() is False client.proxy_mcp_client = Mock() assert client.is_connected() is True def test_client_factory_connection_live(self, mock_wss): mock_wss.connection_live.set() client = session.ColabProxyClient(mock_wss) client.proxy_mcp_client = Mock() assert client.client_factory() is client.proxy_mcp_client def test_client_factory_connection_not_live(self, mock_wss): client = session.ColabProxyClient(mock_wss) assert client.client_factory() is client.stubbed_mcp_client @pytest.mark.asyncio async def test_await_proxy_connection(self, mock_wss): client = session.ColabProxyClient(mock_wss) client._start_task = asyncio.create_task(asyncio.sleep(0.01)) mock_wss.connection_live.set() with patch("colab_mcp.session.UI_CONNECTION_TIMEOUT", 0.1): await client.await_proxy_connection() await client._start_task @pytest.mark.asyncio @patch("colab_mcp.session.Client") @patch("colab_mcp.session.ColabTransport", spec=session.ColabTransport) async def test_start_proxy_client( self, mock_colab_transport, mock_client, mock_wss ): mock_client.return_value.__aenter__ = AsyncMock() client = session.ColabProxyClient(mock_wss) mock_wss.connection_live.set() async with client: await client._start_task mock_colab_transport.assert_called_once_with(mock_wss) mock_client.assert_called_with(mock_colab_transport.return_value) class TestColabTransport: @pytest.mark.asyncio @patch("colab_mcp.session.ClientSession") async def test_connect_session(self, mock_client_session, mock_wss): transport = session.ColabTransport(mock_wss) mock_client_session.return_value.__aenter__ = AsyncMock() async with transport.connect_session(foo="bar") as client_session: assert ( client_session == mock_client_session.return_value.__aenter__.return_value ) mock_client_session.assert_called_once_with( mock_wss.read_stream, mock_wss.write_stream, foo="bar" ) class TestColabSessionProxy: @pytest.mark.asyncio @patch("colab_mcp.session.ToolInjectionMiddleware") @patch("colab_mcp.session.ColabWebSocketServer") @patch("colab_mcp.session.ColabProxyClient") @patch("colab_mcp.session.ColabProxyMiddleware") async def test_start_proxy_server( self, mock_colab_proxy_middleware, mock_colab_proxy_client, mock_colab_web_socket_server, mock_tool_injection_middleware, ): mock_colab_web_socket_server.return_value.__aenter__ = AsyncMock() mock_colab_proxy_client.return_value.__aenter__ = AsyncMock() proxy = session.ColabSessionProxy() await proxy.start_proxy_server() mock_colab_proxy_client.assert_called_once() assert proxy.proxy_server is not None mock_colab_proxy_middleware.assert_called_once() mock_tool_injection_middleware.assert_called_once() @pytest.mark.asyncio async def test_cleanup(self): proxy = session.ColabSessionProxy() proxy._exit_stack = AsyncMock() await proxy.cleanup() proxy._exit_stack.aclose.assert_called_once()