Skip to content

Commit 6b0e4e6

Browse files
committed
Cleanup
1 parent ffb0104 commit 6b0e4e6

15 files changed

Lines changed: 344 additions & 269 deletions

rust/index/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ hnswlib = { workspace = true }
3737
opentelemetry = { version = "0.27.0", default-features = false, features = ["trace", "metrics"] }
3838
simsimd = { workspace = true }
3939
dashmap = { workspace = true }
40-
half = "2"
40+
half = { workspace = true }
4141
usearch = { workspace = true, optional = true }
4242
faer = { workspace = true }
4343

rust/index/examples/sparse_vector_benchmark.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,7 @@ use chroma_benchmark::datasets::wikipedia_splade::{SparseDocument, SparseQuery,
5656
use chroma_blockstore::arrow::provider::BlockfileReaderOptions;
5757
use chroma_blockstore::test_arrow_blockfile_provider;
5858
use chroma_blockstore::{provider::BlockfileProvider, BlockfileWriterOptions};
59-
use chroma_index::sparse::{
60-
reader::{Score, SparseReader},
61-
writer::SparseWriter,
62-
};
59+
use chroma_index::sparse::{reader::SparseReader, types::Score, writer::SparseWriter};
6360
use chroma_types::SignedRoaringBitmap;
6461
use clap::Parser;
6562
use futures::{StreamExt, TryStreamExt};

rust/index/src/sparse/maxscore.rs

Lines changed: 56 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
use std::collections::BinaryHeap;
21
use std::sync::Arc;
32

43
use async_trait::async_trait;
54
use chroma_blockstore::{BlockfileFlusher, BlockfileReader, BlockfileWriter};
65
use chroma_error::{ChromaError, ErrorCodes};
76
use chroma_types::{
87
Directory, DirectoryBlock, SignedRoaringBitmap, SparsePostingBlock, DIRECTORY_PREFIX,
8+
MAX_BLOCK_ENTRIES,
99
};
1010
use dashmap::DashMap;
11+
use futures::StreamExt;
1112
use thiserror::Error;
1213
use uuid::Uuid;
1314

14-
use crate::sparse::types::{decode_u32, encode_u32};
15+
use crate::sparse::types::{decode_u32, encode_u32, Score, TopKHeap};
1516

1617
// ── Two-phase re-scoring ────────────────────────────────────────────
1718

@@ -39,22 +40,12 @@ pub async fn rescore_and_select(
3940
let doc_ids: Vec<u32> = candidates.iter().map(|s| s.offset).collect();
4041
let exact_scores = rescorer.rescore_batch(&doc_ids, query).await;
4142

42-
let mut heap: BinaryHeap<Score> = BinaryHeap::with_capacity(k);
43+
let mut heap = TopKHeap::new(k);
4344
for (i, &score) in exact_scores.iter().enumerate() {
44-
if heap.len() < k || score > heap.peek().map(|s| s.score).unwrap_or(f32::MIN) {
45-
heap.push(Score {
46-
score,
47-
offset: doc_ids[i],
48-
});
49-
if heap.len() > k {
50-
heap.pop();
51-
}
52-
}
45+
heap.push(score, doc_ids[i]);
5346
}
5447

55-
let mut results: Vec<Score> = heap.into_vec();
56-
results.sort_by(|a, b| b.score.total_cmp(&a.score).then(a.offset.cmp(&b.offset)));
57-
results
48+
heap.into_sorted_vec()
5849
}
5950

6051
const DEFAULT_BLOCK_SIZE: u32 = 1024;
@@ -75,34 +66,6 @@ impl ChromaError for BlockSparseError {
7566
}
7667
}
7768

78-
// ── Score type ──────────────────────────────────────────────────────
79-
80-
/// A (score, offset) pair with reversed ordering so that `BinaryHeap`
81-
/// acts as a min-heap: the *lowest* score sits at `peek()`, making it
82-
/// cheap to maintain a top-k set.
83-
#[derive(Debug, PartialEq)]
84-
pub struct Score {
85-
pub score: f32,
86-
pub offset: u32,
87-
}
88-
89-
impl Eq for Score {}
90-
91-
impl Ord for Score {
92-
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
93-
self.score
94-
.total_cmp(&other.score)
95-
.then(self.offset.cmp(&other.offset))
96-
.reverse()
97-
}
98-
}
99-
100-
impl PartialOrd for Score {
101-
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
102-
Some(self.cmp(other))
103-
}
104-
}
105-
10669
// ── BlockSparseFlusher ──────────────────────────────────────────────
10770

10871
pub struct BlockSparseFlusher {
@@ -146,7 +109,14 @@ impl<'me> BlockSparseWriter<'me> {
146109
}
147110

148111
pub fn with_block_size(mut self, block_size: u32) -> Self {
149-
self.block_size = block_size;
112+
if block_size > MAX_BLOCK_ENTRIES as u32 {
113+
tracing::warn!(
114+
requested = block_size,
115+
max = MAX_BLOCK_ENTRIES,
116+
"block_size exceeds MAX_BLOCK_ENTRIES, clamping"
117+
);
118+
}
119+
self.block_size = block_size.min(MAX_BLOCK_ENTRIES as u32);
150120
self
151121
}
152122

@@ -190,6 +160,12 @@ impl<'me> BlockSparseWriter<'me> {
190160
// blockfile's ordered_mutations requirement since all "d"-prefixed
191161
// directory keys sort after the plain base64 posting keys for
192162
// realistic dimension IDs.
163+
debug_assert!(
164+
encoded_dims
165+
.iter()
166+
.all(|(enc, _)| enc.as_str() < DIRECTORY_PREFIX),
167+
"encoded dimension prefix >= DIRECTORY_PREFIX; ordered_mutations invariant broken"
168+
);
193169
struct DirWork {
194170
prefix: String,
195171
directory: Option<Directory>,
@@ -199,14 +175,14 @@ impl<'me> BlockSparseWriter<'me> {
199175

200176
// ── Pass 1: posting blocks ─────────────────────────────────────
201177
for (encoded_dim, dimension_id) in &encoded_dims {
202-
let delta_updates = self.delta.remove(dimension_id);
203-
204-
if delta_updates.is_none() {
178+
let Some((_, updates)) = self.delta.remove(dimension_id) else {
205179
continue;
206-
}
207-
208-
let (_, updates) = delta_updates.unwrap();
180+
};
209181

182+
// NOTE: This is a full read-modify-write — all existing entries for
183+
// the dimension are loaded, merged with deltas, and rewritten. This
184+
// is O(n) per dimension regardless of delta size. A future optimization
185+
// could do in-place block patching for small deltas.
210186
let mut entries = std::collections::HashMap::new();
211187
let mut old_block_count = 0u32;
212188
let mut old_dir_part_count = 0u32;
@@ -385,7 +361,7 @@ impl PostingCursor {
385361

386362
while self.pos < offsets.len() {
387363
let off = offsets[self.pos];
388-
if passes_mask(off, mask) {
364+
if mask.contains(off) {
389365
return Some((off, values[self.pos]));
390366
}
391367
self.pos += 1;
@@ -491,7 +467,7 @@ impl PostingCursor {
491467
if doc > window_end {
492468
return;
493469
}
494-
if passes_mask(doc, mask) {
470+
if mask.contains(doc) {
495471
let idx = (doc - window_start) as usize;
496472
bitmap[idx >> 6] |= 1u64 << (idx & 63);
497473
accum[idx] += vals[self.pos] * query_weight;
@@ -560,13 +536,6 @@ impl PostingCursor {
560536
}
561537
}
562538

563-
fn passes_mask(offset: u32, mask: &SignedRoaringBitmap) -> bool {
564-
match mask {
565-
SignedRoaringBitmap::Include(rbm) => rbm.contains(offset),
566-
SignedRoaringBitmap::Exclude(rbm) => !rbm.contains(offset),
567-
}
568-
}
569-
570539
// ── BlockSparseReader ───────────────────────────────────────────────
571540

572541
#[derive(Clone)]
@@ -625,17 +594,24 @@ impl<'me> BlockSparseReader<'me> {
625594
Ok(parts.len())
626595
}
627596

597+
/// Return all dimension IDs stored in the blockfile.
598+
///
599+
/// Scans only directory entries (prefix "d"...) which are much fewer
600+
/// than posting blocks. A key-only scan API on BlockfileReader would
601+
/// avoid deserializing even the directory values.
628602
pub async fn get_all_dimension_ids(&self) -> Result<Vec<u32>, BlockSparseError> {
629-
let all: Vec<(&str, u32, SparsePostingBlock)> =
630-
self.posting_reader.get_range(.., ..).await?.collect();
603+
let dir_entries: Vec<(&str, u32, SparsePostingBlock)> = self
604+
.posting_reader
605+
.get_range(DIRECTORY_PREFIX.., ..)
606+
.await?
607+
.collect();
631608

632-
let mut dims: Vec<u32> = all
609+
let mut dims: Vec<u32> = dir_entries
633610
.iter()
634611
.filter_map(|(prefix, _, _)| {
635-
if prefix.starts_with(DIRECTORY_PREFIX) {
636-
return None;
637-
}
638-
decode_u32(prefix).ok()
612+
prefix
613+
.strip_prefix(DIRECTORY_PREFIX)
614+
.and_then(|rest| decode_u32(rest).ok())
639615
})
640616
.collect();
641617
dims.sort_unstable();
@@ -673,12 +649,19 @@ impl<'me> BlockSparseReader<'me> {
673649
let collected: Vec<(u32, f32)> = query_vector.into_iter().collect();
674650
let encoded_dims: Vec<String> = collected.iter().map(|(d, _)| encode_u32(*d)).collect();
675651

652+
// Open cursors for all query dimensions in parallel.
653+
let cursor_results: Vec<Result<Option<PostingCursor>, BlockSparseError>> =
654+
futures::stream::iter(encoded_dims.iter().map(|enc| self.open_cursor(enc)))
655+
.buffer_unordered(encoded_dims.len())
656+
.collect()
657+
.await;
658+
676659
let mut terms: Vec<TermState> = Vec::new();
677-
for (idx, &(_, query_weight)) in collected.iter().enumerate() {
678-
let encoded = &encoded_dims[idx];
679-
let Some(mut cursor) = self.open_cursor(encoded).await? else {
660+
for (idx, result) in cursor_results.into_iter().enumerate() {
661+
let Some(mut cursor) = result? else {
680662
continue;
681663
};
664+
let query_weight = collected[idx].1;
682665
cursor.advance(0, &mask);
683666
let max_score = query_weight * cursor.dimension_max();
684667
terms.push(TermState {
@@ -696,8 +679,8 @@ impl<'me> BlockSparseReader<'me> {
696679
terms.sort_by(|a, b| a.max_score.total_cmp(&b.max_score));
697680

698681
let k_usize = k as usize;
699-
let mut threshold = f32::MIN;
700-
let mut heap: BinaryHeap<Score> = BinaryHeap::with_capacity(k_usize);
682+
let mut heap = TopKHeap::new(k_usize);
683+
let mut threshold = heap.threshold();
701684

702685
const WINDOW_WIDTH: u32 = 4096;
703686
const BITMAP_WORDS: usize = (WINDOW_WIDTH as usize).div_ceil(64);
@@ -805,16 +788,7 @@ impl<'me> BlockSparseReader<'me> {
805788

806789
// Phase 3: extract to heap and reset accumulator
807790
for (ci, &doc) in cand_docs.iter().enumerate() {
808-
let score = cand_scores[ci];
809-
if score > threshold || heap.len() < k_usize {
810-
heap.push(Score { score, offset: doc });
811-
if heap.len() > k_usize {
812-
heap.pop();
813-
}
814-
if heap.len() == k_usize {
815-
threshold = heap.peek().map(|s| s.score).unwrap_or(f32::MIN);
816-
}
817-
}
791+
threshold = heap.push(cand_scores[ci], doc);
818792
}
819793

820794
// Zero accum slots + clear bitmap using the bitmap itself
@@ -834,16 +808,13 @@ impl<'me> BlockSparseReader<'me> {
834808
}
835809
}
836810

837-
let mut results: Vec<Score> = heap.into_vec();
838-
results.sort_by(|a, b| b.score.total_cmp(&a.score).then(a.offset.cmp(&b.offset)));
839-
Ok(results)
811+
Ok(heap.into_sorted_vec())
840812
}
841813
}
842814

843815
struct TermState {
844816
cursor: PostingCursor,
845817
query_weight: f32,
846-
#[allow(dead_code)]
847818
max_score: f32,
848819
window_score: f32,
849820
}
@@ -871,39 +842,6 @@ fn filter_competitive(cand_docs: &mut Vec<u32>, cand_scores: &mut Vec<f32>, cuto
871842
mod tests {
872843
use super::*;
873844

874-
#[test]
875-
fn score_min_heap_ordering() {
876-
let mut heap = BinaryHeap::new();
877-
heap.push(Score {
878-
score: 3.0,
879-
offset: 1,
880-
});
881-
heap.push(Score {
882-
score: 1.0,
883-
offset: 2,
884-
});
885-
heap.push(Score {
886-
score: 2.0,
887-
offset: 3,
888-
});
889-
assert_eq!(heap.peek().unwrap().score, 1.0);
890-
heap.pop();
891-
assert_eq!(heap.peek().unwrap().score, 2.0);
892-
}
893-
894-
#[test]
895-
fn score_tiebreak_by_offset() {
896-
let a = Score {
897-
score: 1.0,
898-
offset: 10,
899-
};
900-
let b = Score {
901-
score: 1.0,
902-
offset: 20,
903-
};
904-
assert!(a > b); // reversed: higher offset = "lower" priority
905-
}
906-
907845
#[test]
908846
fn filter_competitive_removes_below_cutoff() {
909847
let mut docs = vec![1, 2, 3, 4, 5];

0 commit comments

Comments (0)