Skip to content

Commit 1bdfd6d

Browse files
authored
fix: extract client info per-request in stateless mode to prevent cross-user bleed (#28)
1 parent 3d73bf7 commit 1bdfd6d

File tree

7 files changed

+144
-48
lines changed

7 files changed

+144
-48
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "mcpcat"
3-
version = "0.1.15b1"
3+
version = "0.1.15b2"
44
description = "Analytics Tool for MCP Servers - provides insights into MCP tool usage patterns"
55
authors = [
66
{ name = "MCPCat", email = "support@mcpcat.io" },

src/mcpcat/modules/overrides/community/monkey_patch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,10 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult:
9898

9999
# Handle session identification
100100
try:
101-
get_client_info_from_request_context(lowlevel_server, request_context)
101+
client_name, client_version = get_client_info_from_request_context(lowlevel_server, request_context)
102102
identity = identify_session(lowlevel_server, request, request_context)
103103
except Exception as e:
104+
client_name, client_version = None, None
104105
identity = None
105106
write_to_log(f"Non-critical error in session handling: {e}")
106107

@@ -124,6 +125,8 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult:
124125
identify_actor_given_id=identity.user_id if identity else None,
125126
identify_actor_name=identity.user_name if identity else None,
126127
identify_data=identity.user_data if identity else None,
128+
client_name=client_name,
129+
client_version=client_version,
127130
)
128131

129132
try:

src/mcpcat/modules/overrides/community_v3/middleware.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,20 +93,24 @@ async def on_initialize(
9393
session_id = self._get_session_id()
9494
params = context.message.params
9595

96-
# Extract client info from initialize params
96+
# Extract client info from initialize params (MCP protocol provides clientInfo here)
97+
client_name, client_version = None, None
9798
if params and hasattr(params, "clientInfo") and params.clientInfo:
9899
client_info = params.clientInfo
99100
if hasattr(client_info, "name") and client_info.name:
100-
self.mcpcat_data.session_info.client_name = client_info.name
101+
client_name = client_info.name
101102
if hasattr(client_info, "version") and client_info.version:
102-
self.mcpcat_data.session_info.client_version = client_info.version
103+
client_version = client_info.version
103104

104105
# Handle session identification
105106
# Note: Use self.server (FastMCP) not self.server._mcp_server because
106107
# tracking data is stored with the FastMCP server as the key for v3
107108
request_context = self._get_request_context(context)
108109
try:
109-
get_client_info_from_request_context(self.server, request_context)
110+
if not client_name:
111+
client_name, client_version = get_client_info_from_request_context(self.server, request_context)
112+
else:
113+
get_client_info_from_request_context(self.server, request_context)
110114
identity = identify_session(self.server, context.message, request_context)
111115
except Exception as e:
112116
identity = None
@@ -120,6 +124,8 @@ async def on_initialize(
120124
identify_actor_given_id=identity.user_id if identity else None,
121125
identify_actor_name=identity.user_name if identity else None,
122126
identify_data=identity.user_data if identity else None,
127+
client_name=client_name,
128+
client_version=client_version,
123129
)
124130

125131
try:
@@ -157,9 +163,10 @@ async def on_call_tool(
157163
# tracking data is stored with the FastMCP server as the key for v3
158164
request_context = self._get_request_context(context)
159165
try:
160-
get_client_info_from_request_context(self.server, request_context)
166+
client_name, client_version = get_client_info_from_request_context(self.server, request_context)
161167
identity = identify_session(self.server, context.message, request_context)
162168
except Exception as e:
169+
client_name, client_version = None, None
163170
identity = None
164171
write_to_log(f"Non-critical error in session handling: {e}")
165172

@@ -188,6 +195,8 @@ async def on_call_tool(
188195
identify_actor_given_id=identity.user_id if identity else None,
189196
identify_actor_name=identity.user_name if identity else None,
190197
identify_data=identity.user_data if identity else None,
198+
client_name=client_name,
199+
client_version=client_version,
191200
)
192201

193202
# Create modified context without context parameter if needed
@@ -248,9 +257,10 @@ async def on_list_tools(
248257
# tracking data is stored with the FastMCP server as the key for v3
249258
request_context = self._get_request_context(context)
250259
try:
251-
get_client_info_from_request_context(self.server, request_context)
260+
client_name, client_version = get_client_info_from_request_context(self.server, request_context)
252261
identity = identify_session(self.server, context.message, request_context)
253262
except Exception as e:
263+
client_name, client_version = None, None
254264
identity = None
255265
write_to_log(f"Non-critical error in session handling: {e}")
256266

@@ -264,6 +274,8 @@ async def on_list_tools(
264274
identify_actor_given_id=identity.user_id if identity else None,
265275
identify_actor_name=identity.user_name if identity else None,
266276
identify_data=identity.user_data if identity else None,
277+
client_name=client_name,
278+
client_version=client_version,
267279
)
268280

269281
try:

src/mcpcat/modules/overrides/mcp_server.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult
4343
request_context = safe_request_context(server)
4444
identity = identify_session(server, request, request_context)
4545

46+
# Extract clientInfo from InitializeRequest params (MCP protocol provides it here)
47+
client_name, client_version = None, None
48+
if request.params and hasattr(request.params, 'clientInfo') and request.params.clientInfo:
49+
client_name = request.params.clientInfo.name
50+
client_version = getattr(request.params.clientInfo, 'version', None)
51+
if not client_name:
52+
client_name, client_version = get_client_info_from_request_context(server, request_context)
53+
4654
event = UnredactedEvent(
4755
session_id=session_id,
4856
timestamp=datetime.now(timezone.utc),
@@ -51,13 +59,13 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult
5159
identify_actor_given_id=identity.user_id if identity else None,
5260
identify_actor_name=identity.user_name if identity else None,
5361
identify_data=identity.user_data if identity else None,
62+
client_name=client_name,
63+
client_version=client_version,
5464
)
5565

5666
# Call the original handler
5767
result = await original_initialize_handler(request)
5868

59-
# TODO: Grab client and server information from the request
60-
6169
# Record the event
6270
event.response = result.model_dump() if result else None
6371
event_queue.publish_event(server, event)
@@ -67,7 +75,7 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult:
6775
"""Intercept list_tools requests to add MCPCat tools and modify existing ones."""
6876
session_id = get_server_session_id(server)
6977
request_context = safe_request_context(server)
70-
get_client_info_from_request_context(server, request_context)
78+
client_name, client_version = get_client_info_from_request_context(server, request_context)
7179
identity = identify_session(server, request, request_context)
7280

7381
event = UnredactedEvent(
@@ -80,6 +88,8 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult:
8088
identify_actor_given_id=identity.user_id if identity else None,
8189
identify_actor_name=identity.user_name if identity else None,
8290
identify_data=identity.user_data if identity else None,
91+
client_name=client_name,
92+
client_version=client_version,
8393
)
8494

8595
# Call the original handler to get the tools
@@ -149,7 +159,7 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult:
149159
arguments = request.params.arguments or {}
150160
session_id = get_server_session_id(server)
151161
request_context = safe_request_context(server)
152-
get_client_info_from_request_context(server, request_context)
162+
client_name, client_version = get_client_info_from_request_context(server, request_context)
153163
identity = identify_session(server, request, request_context)
154164

155165
write_to_log(
@@ -164,6 +174,8 @@ async def wrapped_call_tool_handler(request: CallToolRequest) -> ServerResult:
164174
identify_actor_given_id=identity.user_id if identity else None,
165175
identify_actor_name=identity.user_name if identity else None,
166176
identify_data=identity.user_data if identity else None,
177+
client_name=client_name,
178+
client_version=client_version,
167179
)
168180

169181
# Extract user intent from context (but don't pop yet - we need it for the event)
@@ -237,6 +249,13 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult
237249
identity = None
238250
write_to_log(f"Ran into an error in session identification, no identity could be determined: {e}")
239251

252+
client_name, client_version = None, None
253+
if request.params and hasattr(request.params, 'clientInfo') and request.params.clientInfo:
254+
client_name = request.params.clientInfo.name
255+
client_version = getattr(request.params.clientInfo, 'version', None)
256+
if not client_name:
257+
client_name, client_version = get_client_info_from_request_context(server, request_context)
258+
240259
event = UnredactedEvent(
241260
session_id=session_id,
242261
timestamp=datetime.now(timezone.utc),
@@ -245,6 +264,8 @@ async def wrapped_initialize_handler(request: InitializeRequest) -> ServerResult
245264
identify_actor_given_id=identity.user_id if identity else None,
246265
identify_actor_name=identity.user_name if identity else None,
247266
identify_data=identity.user_data if identity else None,
267+
client_name=client_name,
268+
client_version=client_version,
248269
)
249270

250271
# Call the original handler
@@ -259,7 +280,7 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult:
259280
"""Intercept list_tools requests to track the event (tool modifications handled by monkey-patch)."""
260281
session_id = get_server_session_id(server)
261282
request_context = safe_request_context(server)
262-
get_client_info_from_request_context(server, request_context)
283+
client_name, client_version = get_client_info_from_request_context(server, request_context)
263284
identity = identify_session(server, request, request_context)
264285

265286
event = UnredactedEvent(
@@ -272,6 +293,8 @@ async def wrapped_list_tools_handler(request: ListToolsRequest) -> ServerResult:
272293
identify_actor_given_id=identity.user_id if identity else None,
273294
identify_actor_name=identity.user_name if identity else None,
274295
identify_data=identity.user_data if identity else None,
296+
client_name=client_name,
297+
client_version=client_version,
275298
)
276299

277300
# Call the original handler - tool modifications are handled by monkey-patch

src/mcpcat/modules/overrides/official/monkey_patch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -239,9 +239,9 @@ async def patched_call_tool(
239239
# Handle session identification (non-critical)
240240
try:
241241
request_context = safe_request_context(server._mcp_server)
242-
# Only call if request_context is not None
242+
client_name, client_version = (None, None)
243243
if request_context is not None:
244-
get_client_info_from_request_context(
244+
client_name, client_version = get_client_info_from_request_context(
245245
server._mcp_server, request_context
246246
)
247247

@@ -261,9 +261,9 @@ async def patched_call_tool(
261261

262262
identity = identify_session(server._mcp_server, mock_request, request_context)
263263
except Exception as e:
264+
client_name, client_version = None, None
264265
identity = None
265266
write_to_log(f"Non-critical error in session handling: {e}")
266-
# Continue without session identification
267267

268268
# Extract user intent (non-critical)
269269
user_intent = None
@@ -298,6 +298,8 @@ async def patched_call_tool(
298298
identify_actor_given_id=identity.user_id if identity else None,
299299
identify_actor_name=identity.user_name if identity else None,
300300
identify_data=identity.user_data if identity else None,
301+
client_name=client_name,
302+
client_version=client_version,
301303
)
302304
except Exception as e:
303305
write_to_log(f"Error creating event: {e}")

src/mcpcat/modules/session.py

Lines changed: 40 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -59,80 +59,89 @@ def get_headers_from_request_context(
5959

6060
def get_client_info_from_request_context(
6161
server: Server, request_context: RequestContext | None
62-
) -> None:
62+
) -> tuple[str | None, str | None]:
6363
"""Extract client information from request context or HTTP headers.
6464
65+
Returns (client_name, client_version). In stateless mode, extracts per-request
66+
without caching. In stateful mode, caches on shared session_info.
67+
6568
This function is designed to be resilient and never fail - any error is logged
6669
but won't affect the server operation.
6770
"""
6871
# Handle None request_context (e.g., in stateless HTTP mode outside handlers)
6972
if request_context is None:
7073
write_to_log("Request context is None, skipping client info extraction")
71-
return
74+
return (None, None)
7275

7376
try:
7477
data = get_server_tracking_data(server)
7578
if not data:
76-
return
79+
return (None, None)
80+
81+
client_name: str | None = None
82+
client_version: str | None = None
7783

78-
# If client name and version are already set, no need to fetch again
79-
if data.session_info.client_name and data.session_info.client_version:
80-
return
84+
# In stateful mode, return cached values if already set
85+
if not data.is_stateless and data.session_info.client_name and data.session_info.client_version:
86+
return (data.session_info.client_name, data.session_info.client_version)
8187

8288
try:
83-
# Try to get from session (stateful mode)
89+
# Try to get from MCP session (stateful mode)
8490
if hasattr(request_context, "session") and request_context.session:
8591
client_info = request_context.session.client_params.clientInfo
8692
if client_info:
87-
data.session_info.client_name = client_info.name
88-
data.session_info.client_version = client_info.version
89-
set_server_tracking_data(server, data)
90-
return
91-
except (AttributeError, TypeError) as e:
93+
client_name = client_info.name
94+
client_version = client_info.version
95+
if not data.is_stateless:
96+
data.session_info.client_name = client_name
97+
data.session_info.client_version = client_version
98+
set_server_tracking_data(server, data)
99+
return (client_name, client_version)
100+
except (AttributeError, TypeError):
92101
# This is expected in stateless mode, just continue
93102
pass
94103
except Exception as e:
95-
# Unexpected error, log but continue
96104
write_to_log(f"Error extracting client info from session: {e}")
97105

98106
# Fallback: Try to extract from HTTP headers (stateless mode)
99107
try:
100108
headers = get_headers_from_request_context(request_context)
101109
if headers:
102-
# Check User-Agent header
110+
# Parse User-Agent header (format: "ClientName/Version ...")
103111
user_agent = headers.get("user-agent", "")
104112
if user_agent:
105-
# Parse User-Agent for client info
106-
# Format could be: "ClientName/Version (additional info)"
107113
match = re.match(r"^([^/]+)/([^\s]+)", user_agent)
108114
if match:
109-
data.session_info.client_name = match.group(1)
110-
data.session_info.client_version = match.group(2)
115+
client_name = match.group(1)
116+
client_version = match.group(2)
111117
else:
112-
# If no neat match, use the whole string as client_name
113-
data.session_info.client_name = user_agent
118+
# No neat match, use the whole string as client_name
119+
client_name = user_agent
114120

115-
# Also check custom MCP headers if any
116-
# Clients might send: X-MCP-Client-Name, X-MCP-Client-Version
121+
# Custom MCP headers override User-Agent if present
117122
if headers.get("x-mcp-client-name"):
118-
data.session_info.client_name = headers.get("x-mcp-client-name")
123+
client_name = headers.get("x-mcp-client-name")
119124
if headers.get("x-mcp-client-version"):
120-
data.session_info.client_version = headers.get(
121-
"x-mcp-client-version"
122-
)
125+
client_version = headers.get("x-mcp-client-version")
123126

124-
if data.session_info.client_name or data.session_info.client_version:
127+
if not data.is_stateless and (client_name or client_version):
128+
data.session_info.client_name = client_name
129+
data.session_info.client_version = client_version
125130
set_server_tracking_data(server, data)
131+
132+
if client_name or client_version:
126133
write_to_log(
127-
f"Extracted client info from headers: {data.session_info.client_name} v{data.session_info.client_version}"
134+
f"Extracted client info from headers: {client_name} v{client_version}"
128135
)
129136
except Exception as e:
130137
write_to_log(f"Error extracting client info from headers: {e}")
131138
# Continue without client info
139+
140+
return (client_name, client_version)
132141
except Exception as e:
133142
# Catch-all for any unexpected errors - log but never fail
134143
write_to_log(f"Unexpected error in get_client_info_from_request_context: {e}")
135-
# Function continues and returns normally
144+
return (None, None)
136145

137146

138147
def get_session_info(server: Server, data: MCPCatData | None = None) -> SessionInfo:
@@ -148,10 +157,10 @@ def get_session_info(server: Server, data: MCPCatData | None = None) -> SessionI
148157
server_name=server.name if hasattr(server, "name") else None,
149158
server_version=server.version if hasattr(server, "version") else None,
150159
client_name=data.session_info.client_name
151-
if data and data.session_info
160+
if data and data.session_info and not data.is_stateless
152161
else None,
153162
client_version=data.session_info.client_version
154-
if data and data.session_info
163+
if data and data.session_info and not data.is_stateless
155164
else None,
156165
identify_actor_given_id=actor_info.user_id if actor_info else None,
157166
identify_actor_name=actor_info.user_name if actor_info else None,

0 commit comments

Comments
 (0)