1- use std:: collections:: BinaryHeap ;
21use std:: sync:: Arc ;
32
43use async_trait:: async_trait;
54use chroma_blockstore:: { BlockfileFlusher , BlockfileReader , BlockfileWriter } ;
65use chroma_error:: { ChromaError , ErrorCodes } ;
76use chroma_types:: {
87 Directory , DirectoryBlock , SignedRoaringBitmap , SparsePostingBlock , DIRECTORY_PREFIX ,
8+ MAX_BLOCK_ENTRIES ,
99} ;
1010use dashmap:: DashMap ;
11+ use futures:: StreamExt ;
1112use thiserror:: Error ;
1213use uuid:: Uuid ;
1314
14- use crate :: sparse:: types:: { decode_u32, encode_u32} ;
15+ use crate :: sparse:: types:: { decode_u32, encode_u32, Score , TopKHeap } ;
1516
1617// ── Two-phase re-scoring ────────────────────────────────────────────
1718
@@ -39,22 +40,12 @@ pub async fn rescore_and_select(
3940 let doc_ids: Vec < u32 > = candidates. iter ( ) . map ( |s| s. offset ) . collect ( ) ;
4041 let exact_scores = rescorer. rescore_batch ( & doc_ids, query) . await ;
4142
42- let mut heap: BinaryHeap < Score > = BinaryHeap :: with_capacity ( k) ;
43+ let mut heap = TopKHeap :: new ( k) ;
4344 for ( i, & score) in exact_scores. iter ( ) . enumerate ( ) {
44- if heap. len ( ) < k || score > heap. peek ( ) . map ( |s| s. score ) . unwrap_or ( f32:: MIN ) {
45- heap. push ( Score {
46- score,
47- offset : doc_ids[ i] ,
48- } ) ;
49- if heap. len ( ) > k {
50- heap. pop ( ) ;
51- }
52- }
45+ heap. push ( score, doc_ids[ i] ) ;
5346 }
5447
55- let mut results: Vec < Score > = heap. into_vec ( ) ;
56- results. sort_by ( |a, b| b. score . total_cmp ( & a. score ) . then ( a. offset . cmp ( & b. offset ) ) ) ;
57- results
48+ heap. into_sorted_vec ( )
5849}
5950
6051const DEFAULT_BLOCK_SIZE : u32 = 1024 ;
@@ -75,34 +66,6 @@ impl ChromaError for BlockSparseError {
7566 }
7667}
7768
78- // ── Score type ──────────────────────────────────────────────────────
79-
80- /// A (score, offset) pair with reversed ordering so that `BinaryHeap`
81- /// acts as a min-heap: the *lowest* score sits at `peek()`, making it
82- /// cheap to maintain a top-k set.
83- #[ derive( Debug , PartialEq ) ]
84- pub struct Score {
85- pub score : f32 ,
86- pub offset : u32 ,
87- }
88-
89- impl Eq for Score { }
90-
91- impl Ord for Score {
92- fn cmp ( & self , other : & Self ) -> std:: cmp:: Ordering {
93- self . score
94- . total_cmp ( & other. score )
95- . then ( self . offset . cmp ( & other. offset ) )
96- . reverse ( )
97- }
98- }
99-
100- impl PartialOrd for Score {
101- fn partial_cmp ( & self , other : & Self ) -> Option < std:: cmp:: Ordering > {
102- Some ( self . cmp ( other) )
103- }
104- }
105-
10669// ── BlockSparseFlusher ──────────────────────────────────────────────
10770
10871pub struct BlockSparseFlusher {
@@ -146,7 +109,14 @@ impl<'me> BlockSparseWriter<'me> {
146109 }
147110
148111 pub fn with_block_size ( mut self , block_size : u32 ) -> Self {
149- self . block_size = block_size;
112+ if block_size > MAX_BLOCK_ENTRIES as u32 {
113+ tracing:: warn!(
114+ requested = block_size,
115+ max = MAX_BLOCK_ENTRIES ,
116+ "block_size exceeds MAX_BLOCK_ENTRIES, clamping"
117+ ) ;
118+ }
119+ self . block_size = block_size. min ( MAX_BLOCK_ENTRIES as u32 ) ;
150120 self
151121 }
152122
@@ -190,6 +160,12 @@ impl<'me> BlockSparseWriter<'me> {
190160 // blockfile's ordered_mutations requirement since all "d"-prefixed
191161 // directory keys sort after the plain base64 posting keys for
192162 // realistic dimension IDs.
163+ debug_assert ! (
164+ encoded_dims
165+ . iter( )
166+ . all( |( enc, _) | enc. as_str( ) < DIRECTORY_PREFIX ) ,
167+ "encoded dimension prefix >= DIRECTORY_PREFIX; ordered_mutations invariant broken"
168+ ) ;
193169 struct DirWork {
194170 prefix : String ,
195171 directory : Option < Directory > ,
@@ -199,14 +175,14 @@ impl<'me> BlockSparseWriter<'me> {
199175
200176 // ── Pass 1: posting blocks ─────────────────────────────────────
201177 for ( encoded_dim, dimension_id) in & encoded_dims {
202- let delta_updates = self . delta . remove ( dimension_id) ;
203-
204- if delta_updates. is_none ( ) {
178+ let Some ( ( _, updates) ) = self . delta . remove ( dimension_id) else {
205179 continue ;
206- }
207-
208- let ( _, updates) = delta_updates. unwrap ( ) ;
180+ } ;
209181
182+ // NOTE: This is a full read-modify-write — all existing entries for
183+ // the dimension are loaded, merged with deltas, and rewritten. This
184+ // is O(n) per dimension regardless of delta size. A future optimization
185+ // could do in-place block patching for small deltas.
210186 let mut entries = std:: collections:: HashMap :: new ( ) ;
211187 let mut old_block_count = 0u32 ;
212188 let mut old_dir_part_count = 0u32 ;
@@ -385,7 +361,7 @@ impl PostingCursor {
385361
386362 while self . pos < offsets. len ( ) {
387363 let off = offsets[ self . pos ] ;
388- if passes_mask ( off, mask ) {
364+ if mask . contains ( off) {
389365 return Some ( ( off, values[ self . pos ] ) ) ;
390366 }
391367 self . pos += 1 ;
@@ -491,7 +467,7 @@ impl PostingCursor {
491467 if doc > window_end {
492468 return ;
493469 }
494- if passes_mask ( doc, mask ) {
470+ if mask . contains ( doc) {
495471 let idx = ( doc - window_start) as usize ;
496472 bitmap[ idx >> 6 ] |= 1u64 << ( idx & 63 ) ;
497473 accum[ idx] += vals[ self . pos ] * query_weight;
@@ -560,13 +536,6 @@ impl PostingCursor {
560536 }
561537}
562538
563- fn passes_mask ( offset : u32 , mask : & SignedRoaringBitmap ) -> bool {
564- match mask {
565- SignedRoaringBitmap :: Include ( rbm) => rbm. contains ( offset) ,
566- SignedRoaringBitmap :: Exclude ( rbm) => !rbm. contains ( offset) ,
567- }
568- }
569-
570539// ── BlockSparseReader ───────────────────────────────────────────────
571540
572541#[ derive( Clone ) ]
@@ -625,17 +594,24 @@ impl<'me> BlockSparseReader<'me> {
625594 Ok ( parts. len ( ) )
626595 }
627596
597+ /// Return all dimension IDs stored in the blockfile.
598+ ///
599+ /// Scans only directory entries (prefix "d"...) which are much fewer
600+ /// than posting blocks. A key-only scan API on BlockfileReader would
601+ /// avoid deserializing even the directory values.
628602 pub async fn get_all_dimension_ids ( & self ) -> Result < Vec < u32 > , BlockSparseError > {
629- let all: Vec < ( & str , u32 , SparsePostingBlock ) > =
630- self . posting_reader . get_range ( .., ..) . await ?. collect ( ) ;
603+ let dir_entries: Vec < ( & str , u32 , SparsePostingBlock ) > = self
604+ . posting_reader
605+ . get_range ( DIRECTORY_PREFIX .., ..)
606+ . await ?
607+ . collect ( ) ;
631608
632- let mut dims: Vec < u32 > = all
609+ let mut dims: Vec < u32 > = dir_entries
633610 . iter ( )
634611 . filter_map ( |( prefix, _, _) | {
635- if prefix. starts_with ( DIRECTORY_PREFIX ) {
636- return None ;
637- }
638- decode_u32 ( prefix) . ok ( )
612+ prefix
613+ . strip_prefix ( DIRECTORY_PREFIX )
614+ . and_then ( |rest| decode_u32 ( rest) . ok ( ) )
639615 } )
640616 . collect ( ) ;
641617 dims. sort_unstable ( ) ;
@@ -673,12 +649,19 @@ impl<'me> BlockSparseReader<'me> {
673649 let collected: Vec < ( u32 , f32 ) > = query_vector. into_iter ( ) . collect ( ) ;
674650 let encoded_dims: Vec < String > = collected. iter ( ) . map ( |( d, _) | encode_u32 ( * d) ) . collect ( ) ;
675651
652+ // Open cursors for all query dimensions in parallel.
653+ let cursor_results: Vec < Result < Option < PostingCursor > , BlockSparseError > > =
654+ futures:: stream:: iter ( encoded_dims. iter ( ) . map ( |enc| self . open_cursor ( enc) ) )
655+ . buffer_unordered ( encoded_dims. len ( ) )
656+ . collect ( )
657+ . await ;
658+
676659 let mut terms: Vec < TermState > = Vec :: new ( ) ;
677- for ( idx, & ( _, query_weight) ) in collected. iter ( ) . enumerate ( ) {
678- let encoded = & encoded_dims[ idx] ;
679- let Some ( mut cursor) = self . open_cursor ( encoded) . await ? else {
660+ for ( idx, result) in cursor_results. into_iter ( ) . enumerate ( ) {
661+ let Some ( mut cursor) = result? else {
680662 continue ;
681663 } ;
664+ let query_weight = collected[ idx] . 1 ;
682665 cursor. advance ( 0 , & mask) ;
683666 let max_score = query_weight * cursor. dimension_max ( ) ;
684667 terms. push ( TermState {
@@ -696,8 +679,8 @@ impl<'me> BlockSparseReader<'me> {
696679 terms. sort_by ( |a, b| a. max_score . total_cmp ( & b. max_score ) ) ;
697680
698681 let k_usize = k as usize ;
699- let mut threshold = f32 :: MIN ;
700- let mut heap : BinaryHeap < Score > = BinaryHeap :: with_capacity ( k_usize ) ;
682+ let mut heap = TopKHeap :: new ( k_usize ) ;
683+ let mut threshold = heap . threshold ( ) ;
701684
702685 const WINDOW_WIDTH : u32 = 4096 ;
703686 const BITMAP_WORDS : usize = ( WINDOW_WIDTH as usize ) . div_ceil ( 64 ) ;
@@ -805,16 +788,7 @@ impl<'me> BlockSparseReader<'me> {
805788
806789 // Phase 3: extract to heap and reset accumulator
807790 for ( ci, & doc) in cand_docs. iter ( ) . enumerate ( ) {
808- let score = cand_scores[ ci] ;
809- if score > threshold || heap. len ( ) < k_usize {
810- heap. push ( Score { score, offset : doc } ) ;
811- if heap. len ( ) > k_usize {
812- heap. pop ( ) ;
813- }
814- if heap. len ( ) == k_usize {
815- threshold = heap. peek ( ) . map ( |s| s. score ) . unwrap_or ( f32:: MIN ) ;
816- }
817- }
791+ threshold = heap. push ( cand_scores[ ci] , doc) ;
818792 }
819793
820794 // Zero accum slots + clear bitmap using the bitmap itself
@@ -834,16 +808,13 @@ impl<'me> BlockSparseReader<'me> {
834808 }
835809 }
836810
837- let mut results: Vec < Score > = heap. into_vec ( ) ;
838- results. sort_by ( |a, b| b. score . total_cmp ( & a. score ) . then ( a. offset . cmp ( & b. offset ) ) ) ;
839- Ok ( results)
811+ Ok ( heap. into_sorted_vec ( ) )
840812 }
841813}
842814
843815struct TermState {
844816 cursor : PostingCursor ,
845817 query_weight : f32 ,
846- #[ allow( dead_code) ]
847818 max_score : f32 ,
848819 window_score : f32 ,
849820}
@@ -871,39 +842,6 @@ fn filter_competitive(cand_docs: &mut Vec<u32>, cand_scores: &mut Vec<f32>, cuto
871842mod tests {
872843 use super :: * ;
873844
874- #[ test]
875- fn score_min_heap_ordering ( ) {
876- let mut heap = BinaryHeap :: new ( ) ;
877- heap. push ( Score {
878- score : 3.0 ,
879- offset : 1 ,
880- } ) ;
881- heap. push ( Score {
882- score : 1.0 ,
883- offset : 2 ,
884- } ) ;
885- heap. push ( Score {
886- score : 2.0 ,
887- offset : 3 ,
888- } ) ;
889- assert_eq ! ( heap. peek( ) . unwrap( ) . score, 1.0 ) ;
890- heap. pop ( ) ;
891- assert_eq ! ( heap. peek( ) . unwrap( ) . score, 2.0 ) ;
892- }
893-
894- #[ test]
895- fn score_tiebreak_by_offset ( ) {
896- let a = Score {
897- score : 1.0 ,
898- offset : 10 ,
899- } ;
900- let b = Score {
901- score : 1.0 ,
902- offset : 20 ,
903- } ;
904- assert ! ( a > b) ; // reversed: higher offset = "lower" priority
905- }
906-
907845 #[ test]
908846 fn filter_competitive_removes_below_cutoff ( ) {
909847 let mut docs = vec ! [ 1 , 2 , 3 , 4 , 5 ] ;
0 commit comments