Skip to content

Commit 22054ea

Browse files
perf(mcp): add URL-keyed connection pooling for Redis database tools (#808)
Add a connection cache to CachedClients (keyed by resolved URL) that returns a shared MultiplexedConnection, matching the existing pattern used for Cloud and Enterprise API clients. - Add `database: HashMap<String, MultiplexedConnection>` to CachedClients - Add `redis_connection_for_url()` on AppState with PING health check and automatic reconnect on stale connections - Add `get_connection()` helper in tools/redis/mod.rs that combines URL resolution and cached connection retrieval - Replace per-handler Client::open + get_multiplexed_async_connection boilerplate across all 55 database tool handlers Connections are keyed by URL (not profile name) so different profiles pointing to the same database share a connection, and raw URL callers benefit from caching too. Closes #801
1 parent cef87d9 commit 22054ea

7 files changed

Lines changed: 246 additions & 742 deletions

File tree

crates/redisctl-mcp/src/state.rs

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//! Application state and credential resolution
22
3-
#[cfg(any(feature = "cloud", feature = "enterprise"))]
3+
#[cfg(any(feature = "cloud", feature = "enterprise", feature = "database"))]
44
use std::collections::HashMap;
55
use std::sync::Arc;
66

@@ -30,12 +30,14 @@ pub enum CredentialSource {
3030
},
3131
}
3232

33-
/// Cached API clients (per-profile for multi-cluster support)
33+
/// Cached API clients and connections (per-profile for multi-cluster support)
3434
pub struct CachedClients {
3535
#[cfg(feature = "cloud")]
3636
pub cloud: HashMap<String, CloudClient>,
3737
#[cfg(feature = "enterprise")]
3838
pub enterprise: HashMap<String, EnterpriseClient>,
39+
#[cfg(feature = "database")]
40+
pub database: HashMap<String, redis::aio::MultiplexedConnection>,
3941
}
4042

4143
/// Shared application state
@@ -85,6 +87,8 @@ impl AppState {
8587
cloud: HashMap::new(),
8688
#[cfg(feature = "enterprise")]
8789
enterprise: HashMap::new(),
90+
#[cfg(feature = "database")]
91+
database: HashMap::new(),
8892
}),
8993
})
9094
}
@@ -340,23 +344,46 @@ impl AppState {
340344
Ok(format!("{}://{}{}:{}{}", scheme, auth, host, port, db_path))
341345
}
342346

343-
/// Get Redis connection for direct database operations
347+
/// Get or create a cached Redis connection for a resolved URL.
348+
///
349+
/// Connections are cached by URL. If a cached connection fails a PING
350+
/// health check, it is evicted and a fresh connection is created.
344351
#[cfg(feature = "database")]
345-
#[allow(dead_code)]
346-
pub async fn redis_connection(&self) -> Result<redis::aio::MultiplexedConnection> {
347-
let url = self
348-
.database_url
349-
.as_ref()
350-
.cloned()
351-
.or_else(|| std::env::var("REDIS_URL").ok())
352-
.context("No Redis URL configured")?;
353-
354-
let client = redis::Client::open(url.as_str()).context("Failed to create Redis client")?;
352+
pub async fn redis_connection_for_url(
353+
&self,
354+
url: &str,
355+
) -> Result<redis::aio::MultiplexedConnection> {
356+
// Check cache first
357+
{
358+
let clients = self.clients.read().await;
359+
if let Some(conn) = clients.database.get(url) {
360+
// Quick health check -- if PING fails the connection is stale
361+
let mut test_conn = conn.clone();
362+
if redis::cmd("PING")
363+
.query_async::<String>(&mut test_conn)
364+
.await
365+
.is_ok()
366+
{
367+
return Ok(conn.clone());
368+
}
369+
// Fall through to evict + reconnect
370+
}
371+
}
355372

356-
client
373+
// Create new connection (or reconnect after eviction)
374+
let client = redis::Client::open(url).context("Failed to create Redis client")?;
375+
let conn = client
357376
.get_multiplexed_async_connection()
358377
.await
359-
.context("Failed to connect to Redis")
378+
.context("Failed to connect to Redis")?;
379+
380+
// Cache it
381+
{
382+
let mut clients = self.clients.write().await;
383+
clients.database.insert(url.to_string(), conn.clone());
384+
}
385+
386+
Ok(conn)
360387
}
361388

362389
/// Check if write operations are allowed by the global policy tier.
@@ -395,6 +422,8 @@ impl Clone for AppState {
395422
cloud: HashMap::new(),
396423
#[cfg(feature = "enterprise")]
397424
enterprise: HashMap::new(),
425+
#[cfg(feature = "database")]
426+
database: HashMap::new(),
398427
}),
399428
}
400429
}
@@ -427,6 +456,8 @@ impl AppState {
427456
cloud,
428457
#[cfg(feature = "enterprise")]
429458
enterprise: HashMap::new(),
459+
#[cfg(feature = "database")]
460+
database: HashMap::new(),
430461
}),
431462
}
432463
}
@@ -446,6 +477,8 @@ impl AppState {
446477
#[cfg(feature = "cloud")]
447478
cloud: HashMap::new(),
448479
enterprise,
480+
#[cfg(feature = "database")]
481+
database: HashMap::new(),
449482
}),
450483
}
451484
}
@@ -466,6 +499,8 @@ impl AppState {
466499
clients: RwLock::new(CachedClients {
467500
cloud: cloud_map,
468501
enterprise: enterprise_map,
502+
#[cfg(feature = "database")]
503+
database: HashMap::new(),
469504
}),
470505
}
471506
}

crates/redisctl-mcp/src/tools/redis/diagnostics.rs

Lines changed: 8 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,8 @@ pub fn health_check(state: Arc<AppState>) -> Tool {
106106
.extractor_handler_typed::<_, _, _, HealthCheckInput>(
107107
state,
108108
|State(state): State<Arc<AppState>>, Json(input): Json<HealthCheckInput>| async move {
109-
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;
110-
111-
let client = redis::Client::open(url.as_str()).tool_context("Invalid URL")?;
112-
113-
let mut conn = client
114-
.get_multiplexed_async_connection()
115-
.await
116-
.tool_context("Connection failed")?;
109+
let mut conn =
110+
super::get_connection(input.url, input.profile.as_deref(), &state).await?;
117111

118112
// PING
119113
let ping_response: String = redis::cmd("PING")
@@ -248,14 +242,8 @@ pub fn key_summary(state: Arc<AppState>) -> Tool {
248242
.extractor_handler_typed::<_, _, _, KeySummaryInput>(
249243
state,
250244
|State(state): State<Arc<AppState>>, Json(input): Json<KeySummaryInput>| async move {
251-
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;
252-
253-
let client = redis::Client::open(url.as_str()).tool_context("Invalid URL")?;
254-
255-
let mut conn = client
256-
.get_multiplexed_async_connection()
257-
.await
258-
.tool_context("Connection failed")?;
245+
let mut conn =
246+
super::get_connection(input.url, input.profile.as_deref(), &state).await?;
259247

260248
// TYPE
261249
let key_type: String = redis::cmd("TYPE")
@@ -368,14 +356,8 @@ pub fn hotkeys(state: Arc<AppState>) -> Tool {
368356
.extractor_handler_typed::<_, _, _, HotkeysInput>(
369357
state,
370358
|State(state): State<Arc<AppState>>, Json(input): Json<HotkeysInput>| async move {
371-
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;
372-
373-
let client = redis::Client::open(url.as_str()).tool_context("Invalid URL")?;
374-
375-
let mut conn = client
376-
.get_multiplexed_async_connection()
377-
.await
378-
.tool_context("Connection failed")?;
359+
let mut conn =
360+
super::get_connection(input.url, input.profile.as_deref(), &state).await?;
379361

380362
let pattern = input.pattern.as_deref().unwrap_or("*");
381363
let sample_size = input.sample_size.unwrap_or(1000).min(MAX_SAMPLE_SIZE);
@@ -511,15 +493,8 @@ pub fn connection_summary(state: Arc<AppState>) -> Tool {
511493
state,
512494
|State(state): State<Arc<AppState>>,
513495
Json(input): Json<ConnectionSummaryInput>| async move {
514-
let url = super::resolve_redis_url(input.url, input.profile.as_deref(), &state)?;
515-
516-
let client = redis::Client::open(url.as_str())
517-
.tool_context("Invalid URL")?;
518-
519-
let mut conn = client
520-
.get_multiplexed_async_connection()
521-
.await
522-
.tool_context("Connection failed")?;
496+
let mut conn =
497+
super::get_connection(input.url, input.profile.as_deref(), &state).await?;
523498

524499
// CLIENT LIST
525500
let client_list_raw: String = redis::cmd("CLIENT")

0 commit comments

Comments
 (0)