Cluster-Aware Retrieval for RAG Systems

Most RAG systems treat embedding spaces as flat, uniform distributions. They're not. Real knowledge bases contain distinct semantic clusters—database docs, frontend frameworks, DevOps practices—each with its own internal structure. Ignoring that structure costs retrieval precision.

The Problem with Flat Retrieval

A query about “React hooks optimization” should pull from the frontend cluster, not equally consider database or infrastructure docs that happen to share semantic overlap. Standard cosine similarity doesn’t care about topical boundaries. You get results that are individually relevant but collectively unfocused.

Modeling Clusters with GMM

Gaussian Mixture Models assume your embeddings arise from \(K\) underlying Gaussian distributions:

$$p(v) = \sum_{k=1}^K \pi_k \mathcal{N}(v \mid \mu_k, \Sigma_k)$$

For a query \(q\), compute the posterior probability of each cluster:

$$p(k \mid q) = \frac{\pi_k \mathcal{N}(q \mid \mu_k, \Sigma_k)}{\sum_{j=1}^K \pi_j \mathcal{N}(q \mid \mu_j, \Sigma_j)}$$

This gives you soft assignments—the probability that a query belongs to each semantic cluster.
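
In code, the posterior falls straight out of the fitted mixture parameters. A minimal sketch, computed in log space to avoid underflow in high-dimensional embedding spaces; `weights`, `means`, and `covs` are assumed to come from a fitted model (e.g. scikit-learn's `gmm.weights_`, `gmm.means_`, `gmm.covariances_`):

import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

def cluster_posteriors(q, weights, means, covs):
    """p(k | q) for each of the K components, in log space for stability."""
    log_joint = np.array([
        np.log(w) + multivariate_normal.logpdf(q, mean=mu, cov=cov)
        for w, mu, cov in zip(weights, means, covs)
    ])
    return np.exp(log_joint - logsumexp(log_joint))

# With a fitted sklearn model:
# cluster_posteriors(q, gmm.weights_, gmm.means_, gmm.covariances_)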

Two-Stage Retrieval

  1. Cluster selection: Pick cluster(s) with highest \(p(k \mid q)\). Take top-2 for ambiguous queries.
  2. Intra-cluster retrieval: Run k-NN within selected clusters.

The cluster boundaries act as a soft filter, avoiding the “dilution effect” where off-topic documents dominate results.
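
A minimal sketch of the two stages, assuming a fitted scikit-learn GaussianMixture, an (N, d) matrix `doc_embeddings`, and hard assignments `doc_clusters = gmm.predict(doc_embeddings)` computed offline:

import numpy as np

def two_stage_retrieve(q, gmm, doc_embeddings, doc_clusters,
                       n_clusters=2, top_k=10):
    # Stage 1: route the query to its most probable clusters
    probs = gmm.predict_proba(q.reshape(1, -1))[0]
    selected = np.argsort(probs)[-n_clusters:]

    # Stage 2: cosine k-NN restricted to documents in those clusters
    candidates = np.where(np.isin(doc_clusters, selected))[0]
    sims = doc_embeddings[candidates] @ q
    sims /= np.linalg.norm(doc_embeddings[candidates], axis=1) * np.linalg.norm(q)
    return candidates[np.argsort(sims)[-top_k:][::-1]]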

Mahalanobis Distance Per Cluster

Here’s the underexplored idea: different clusters can use different distance metrics. For a cluster modeled as \(\mathcal{N}(\mu_k, \Sigma_k)\), the Mahalanobis distance accounts for the cluster’s shape:

$$d_{\text{Mah}}(q, v) = \sqrt{(q - v)^T \Sigma_k^{-1} (q - v)}$$

Clusters elongated along certain semantic directions get their distances scaled appropriately. Cosine similarity treats all directions equally—Mahalanobis adapts to each cluster's shape.
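
A sketch of the per-cluster metric, assuming a fitted GaussianMixture with covariance_type='full'; the inverses are precomputed with a small ridge term (`lam` is an assumed regularization constant) for numerical stability:

import numpy as np

# One inverse covariance per cluster, regularized before inversion
lam = 1e-6
cov_invs = [
    np.linalg.inv(cov + lam * np.eye(cov.shape[0]))
    for cov in gmm.covariances_
]

def mahalanobis(q, v, cov_inv):
    """Distance between q and v under one cluster's covariance."""
    diff = q - v
    return np.sqrt(diff @ cov_inv @ diff)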

Clusters as Agent Tools

In agentic RAG, each cluster becomes a tool the agent can invoke:

# ClusterRetrievalTool is a hypothetical wrapper around intra-cluster
# retrieval; topic_labels maps each cluster id to a human-readable name
tools = [
    ClusterRetrievalTool(cluster_id=k, name=f"Search {topic_labels[k]}")
    for k in range(K)
]

The agent decides which clusters to search and in what order:

  • Query: “How does React’s context API compare to Redux?”
  • Agent plan:
    1. Search frontend cluster for React context
    2. Search state management cluster for Redux patterns
    3. Synthesize comparison

This beats flat retrieval for cross-topic synthesis.
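
A minimal sketch of executing such a plan; `plan`, `tools_by_topic`, and `llm.synthesize` are hypothetical placeholders for whatever agent framework you use:

# Hypothetical plan: each step pairs a cluster tool name with a sub-query
plan = [
    ("frontend", "React context API"),
    ("state_management", "Redux patterns"),
]

def run_plan(query, plan, tools_by_topic, llm):
    passages = []
    for topic, subquery in plan:
        # Each tool runs intra-cluster retrieval over its own cluster
        passages += tools_by_topic[topic].search(subquery)
    # Final synthesis over everything the cluster tools returned
    return llm.synthesize(query, passages)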

Implementation

Fit GMM offline on document embeddings:

from sklearn.mixture import GaussianMixture

# Fit once, offline, on the full (N, d) document embedding matrix
gmm = GaussianMixture(n_components=K, covariance_type='full', random_state=0)
gmm.fit(document_embeddings)

# At query time, q is the query embedding as a 1-D numpy array:
cluster_probs = gmm.predict_proba(q.reshape(1, -1))[0]
selected_clusters = cluster_probs.argsort()[-2:][::-1]  # top-2, highest first

Store cluster assignments as metadata in your vector DB:

results = vector_db.query(
    query_embedding=q,
    # Pinecone-style metadata filter; cast numpy integers to plain ints
    filter={"cluster_id": {"$in": [int(c) for c in selected_clusters]}},
    top_k=20
)

Key decisions:

  • Number of clusters: Use BIC/AIC or domain knowledge (see the sketch after this list)
  • Regularization: Add \(\lambda I\) to covariance matrices to prevent singularities
  • Initialization: k-means++ for better convergence
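
A sketch tying these three together, assuming scikit-learn ≥ 1.1 (needed for init_params='k-means++'); reg_covar is sklearn's built-in \(\lambda I\) ridge on the covariance matrices, and the search range for K is an arbitrary placeholder:

from sklearn.mixture import GaussianMixture

candidate_ks = range(4, 33, 4)  # assumed search range; tune for your corpus
models = [
    GaussianMixture(
        n_components=k,
        covariance_type='full',
        reg_covar=1e-4,           # adds lambda*I to each covariance matrix
        init_params='k-means++',  # better starting points than random
        random_state=0,
    ).fit(document_embeddings)
    for k in candidate_ks
]
best = min(models, key=lambda m: m.bic(document_embeddings))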

When It Helps

  • Topically diverse corpora: Multi-product docs, cross-domain papers
  • Single-topic queries: Clear primary topic to route to
  • Noise reduction: Distant-but-similar content diluting results

When it doesn’t:

  • Homogeneous corpora
  • Very small datasets
  • Queries requiring extensive cross-topic synthesis (agentic patterns help here)

Limitations

Cluster boundaries: Queries near boundaries may be misrouted. Soft routing (weighted retrieval across clusters) helps.
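
One way to implement that soft routing, sketched assuming you already hold per-cluster result lists of (doc_id, score) pairs:

def soft_routed_merge(q, gmm, per_cluster_results):
    """Merge per-cluster results, weighting each score by p(k | q)."""
    probs = gmm.predict_proba(q.reshape(1, -1))[0]
    merged = {}
    # per_cluster_results: {cluster_id: [(doc_id, score), ...]} (assumed shape)
    for k, results in per_cluster_results.items():
        for doc_id, score in results:
            merged[doc_id] = merged.get(doc_id, 0.0) + probs[k] * score
    return sorted(merged.items(), key=lambda kv: kv[1], reverse=True)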

Scalability: GMM fitting doesn’t scale well beyond ~100 clusters and millions of docs. Use hierarchical clustering or vector DB partitioning for large systems.

Benchmark first: Flat retrieval with strong reranking is a tough baseline. Always compare.

The core insight: embedding spaces have structure. Exploit it.
