From 33436c050b86dd4f6154f5c79ccd914c7624be45 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 12 May 2026 12:57:19 -0700 Subject: [PATCH 1/9] Checkpoint. --- diskann-benchmark-runner/dev/main.rs | 10 +- diskann-benchmark-runner/src/any.rs | 12 +- diskann-benchmark-runner/src/app.rs | 54 +++---- diskann-benchmark-runner/src/input.rs | 12 +- .../src/internal/regression.rs | 13 +- diskann-benchmark-runner/src/jobs.rs | 8 +- diskann-benchmark-runner/src/lib.rs | 1 + diskann-benchmark-runner/src/registry.rs | 142 ++++++++++-------- diskann-benchmark-runner/src/test/mod.rs | 23 ++- diskann-benchmark-simd/src/bin.rs | 10 +- diskann-benchmark-simd/src/lib.rs | 89 +++++------ .../src/backend/disk_index/benchmarks.rs | 13 +- .../src/backend/disk_index/mod.rs | 10 +- .../src/backend/exhaustive/minmax.rs | 18 ++- .../src/backend/exhaustive/mod.rs | 11 +- .../src/backend/exhaustive/product.rs | 10 +- .../src/backend/exhaustive/spherical.rs | 18 ++- .../src/backend/filters/benchmark.rs | 6 +- diskann-benchmark/src/backend/filters/mod.rs | 6 +- .../src/backend/index/benchmarks.rs | 43 +++--- diskann-benchmark/src/backend/index/mod.rs | 6 +- .../src/backend/index/product.rs | 16 +- diskann-benchmark/src/backend/index/scalar.rs | 16 +- .../src/backend/index/search/plugins.rs | 2 +- .../src/backend/index/spherical.rs | 20 +-- diskann-benchmark/src/backend/mod.rs | 13 +- diskann-benchmark/src/inputs/disk.rs | 7 - diskann-benchmark/src/inputs/exhaustive.rs | 9 -- diskann-benchmark/src/inputs/filters.rs | 7 - diskann-benchmark/src/inputs/graph_index.rs | 11 -- diskann-benchmark/src/inputs/mod.rs | 10 -- diskann-benchmark/src/main.rs | 10 +- diskann-benchmark/src/utils/mod.rs | 6 +- 33 files changed, 318 insertions(+), 324 deletions(-) diff --git a/diskann-benchmark-runner/dev/main.rs b/diskann-benchmark-runner/dev/main.rs index d94da95fc..9f7dae60d 100644 --- a/diskann-benchmark-runner/dev/main.rs +++ b/diskann-benchmark-runner/dev/main.rs @@ -9,12 +9,8 @@ fn main() -> 
anyhow::Result<()> { // Parse the command line options. let app = App::parse(); - // Gather the test inputs and outputs. - let mut inputs = registry::Inputs::new(); - diskann_benchmark_runner::test::register_inputs(&mut inputs)?; + let mut registry = registry::Benchmarks::new(); + diskann_benchmark_runner::test::register_benchmarks(&mut registry); - let mut benchmarks = registry::Benchmarks::new(); - diskann_benchmark_runner::test::register_benchmarks(&mut benchmarks); - - app.run(&inputs, &benchmarks, &mut output::default()) + app.run(&registry, &mut output::default()) } diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs index b06cfdad3..683614f6b 100644 --- a/diskann-benchmark-runner/src/any.rs +++ b/diskann-benchmark-runner/src/any.rs @@ -3,7 +3,10 @@ * Licensed under the MIT license. */ -use crate::dispatcher::{DispatchRule, FailureScore, MatchScore}; +use crate::{ + dispatcher::{DispatchRule, FailureScore, MatchScore}, + Input, +}; /// An refinement of [`std::any::Any`] with an associated name (tag) and serialization. /// @@ -33,6 +36,13 @@ impl Any { } } + pub fn input(any: T) -> Self + where + T: Input + serde::Serialize + std::fmt::Debug + 'static, + { + Self::new(any, T::tag()) + } + /// A lower level API for constructing an [`Any`] that decouples the serialized /// representation from the inmemory representation. /// diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index ad5ff3936..249907858 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -38,16 +38,13 @@ //! use diskann_benchmark_runner::{App, registry}; //! //! fn main() -> anyhow::Result<()> { -//! let mut inputs = registry::Inputs::new(); -//! // inputs.register::()?; -//! -//! let mut benchmarks = registry::Benchmarks::new(); -//! // benchmarks.register::("my-bench"); -//! // benchmarks.register_regression::("my-regression"); +//! let mut registry = registry::Registry::new(); +//! 
// registry.register::("my-bench"); +//! // registry.register_regression::("my-regression"); //! //! let app = App::parse(); //! let mut output = diskann_benchmark_runner::output::default(); -//! app.run(&inputs, &benchmarks, &mut output) +//! app.run(&registry, &mut output) //! } //! ``` //! @@ -192,15 +189,14 @@ impl App { /// Run the application using the registered `inputs` and `benchmarks`. pub fn run( &self, - inputs: &registry::Inputs, - benchmarks: &registry::Benchmarks, + registry: &registry::Registry, mut output: &mut dyn Output, ) -> anyhow::Result<()> { match &self.command { // If a named benchmark isn't given, then list the available benchmarks. Commands::Inputs { describe } => { if let Some(describe) = describe { - if let Some(input) = inputs.get(describe) { + if let Some(input) = registry.input(describe) { let repr = jobs::Unprocessed::format_input(input)?; writeln!( output, @@ -217,7 +213,7 @@ impl App { } writeln!(output, "Available input kinds are listed below:")?; - let mut tags: Vec<_> = inputs.tags().collect(); + let mut tags: Vec<_> = registry.tags().collect(); tags.sort(); for i in tags.iter() { writeln!(output, " {}", i)?; @@ -226,7 +222,7 @@ impl App { // List the available benchmarks. Commands::Benchmarks {} => { writeln!(output, "Registered Benchmarks:")?; - for (name, description) in benchmarks.names() { + for (name, description) in registry.names() { write!(output, " {name}:")?; if description.is_empty() { writeln!(output)?; @@ -248,11 +244,11 @@ impl App { allow_debug, } => { // Parse and validate the input. - let run = Jobs::load(input_file, inputs)?; + let run = Jobs::load(input_file, registry)?; // Check if we have a match for each benchmark. 
for job in run.jobs().iter() { const MAX_METHODS: usize = 3; - if let Err(mismatches) = benchmarks.debug(job, MAX_METHODS) { + if let Err(mismatches) = registry.debug(job, MAX_METHODS) { let repr = serde_json::to_string_pretty(&job.serialize()?)?; writeln!( @@ -314,7 +310,7 @@ impl App { // Run the specified job. let checkpoint = Checkpoint::new(&serialized, &results, output_file)?; - let r = benchmarks.call(job, checkpoint, output)?; + let r = registry.call(job, checkpoint, output)?; // Collect the results results.push(r); @@ -324,7 +320,7 @@ impl App { } } // Extensions - Commands::Check(check) => return self.check(check, inputs, benchmarks, output), + Commands::Check(check) => return self.check(check, registry, output), }; Ok(()) } @@ -333,8 +329,7 @@ impl App { fn check( &self, check: &Check, - inputs: &registry::Inputs, - benchmarks: &registry::Benchmarks, + registry: &registry::Registry, mut output: &mut dyn Output, ) -> anyhow::Result<()> { match check { @@ -350,7 +345,7 @@ impl App { Ok(()) } Check::Tolerances { describe } => { - let tolerances = benchmarks.tolerances(); + let tolerances = registry.tolerances(); match describe { Some(name) => match tolerances.get(&**name) { @@ -405,12 +400,7 @@ impl App { tolerances, input_file, } => { - // For verification - we merely check that we can successfully construct - // the regression `Checks` struct. It performs all the necessary preflight - // checks. 
- let benchmarks = benchmarks.tolerances(); - let _ = - internal::regression::Checks::new(tolerances, input_file, inputs, &benchmarks)?; + let _ = internal::regression::Checks::new(tolerances, input_file, registry)?; Ok(()) } Check::Run { @@ -420,9 +410,7 @@ impl App { after, output_file, } => { - let registered = benchmarks.tolerances(); - let checks = - internal::regression::Checks::new(tolerances, input_file, inputs, &registered)?; + let checks = internal::regression::Checks::new(tolerances, input_file, registry)?; let jobs = checks.jobs(before, after)?; jobs.run(output, output_file.as_deref())?; Ok(()) @@ -605,13 +593,9 @@ mod tests { fn run(&self, tempdir: &Path) { let apps = self.parse_stdin(tempdir); - // Register inputs - let mut inputs = registry::Inputs::new(); - crate::test::register_inputs(&mut inputs).unwrap(); - // Register outputs - let mut benchmarks = registry::Benchmarks::new(); - crate::test::register_benchmarks(&mut benchmarks); + let mut registry = registry::Registry::new(); + crate::test::register_benchmarks(&mut registry).unwrap(); // Run each app invocation - collecting the last output into a buffer. 
// @@ -631,7 +615,7 @@ mod tests { &mut crate::output::Sink::new() }; - if let Err(err) = app.run(&inputs, &benchmarks, b) { + if let Err(err) = app.run(&registry, b) { if is_last { write!(b, "{:?}", err).unwrap(); } else { diff --git a/diskann-benchmark-runner/src/input.rs b/diskann-benchmark-runner/src/input.rs index 4f8b1523e..07e780b8c 100644 --- a/diskann-benchmark-runner/src/input.rs +++ b/diskann-benchmark-runner/src/input.rs @@ -110,11 +110,15 @@ pub(crate) trait DynInput { checker: &mut Checker, ) -> anyhow::Result; fn example(&self) -> anyhow::Result; + + // reflection + fn as_any(&self) -> &dyn std::any::Any; + fn type_name(&self) -> &'static str; } impl DynInput for Wrapper where - T: Input, + T: Input + 'static, { fn tag(&self) -> &'static str { T::tag() @@ -129,4 +133,10 @@ where fn example(&self) -> anyhow::Result { T::example() } + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn type_name(&self) -> &'static str { + std::any::type_name::() + } } diff --git a/diskann-benchmark-runner/src/internal/regression.rs b/diskann-benchmark-runner/src/internal/regression.rs index f9bc12061..c1d1838ee 100644 --- a/diskann-benchmark-runner/src/internal/regression.rs +++ b/diskann-benchmark-runner/src/internal/regression.rs @@ -126,8 +126,7 @@ impl<'a> Checks<'a> { pub(crate) fn new( tolerances: &Path, input_file: &Path, - inputs: &registry::Inputs, - entries: &'a HashMap<&'static str, registry::RegisteredTolerance<'a>>, + registry: &'a registry::Registry, ) -> anyhow::Result { // Load the raw input file. let partial = jobs::Partial::load(input_file)?; @@ -135,11 +134,11 @@ impl<'a> Checks<'a> { // Parse and validate the raw jobs against the registered inputs. // // This preserves the ordering of the jobs. - let inputs = jobs::Jobs::parse(&partial, inputs)?; + let inputs = jobs::Jobs::parse(&partial, registry)?; // Now that the inputs have been fully parsed and validated, we then check that we // can load the raw tolerance file. 
- let parsed = Raw::load(tolerances)?.parse(entries)?; + let parsed = Raw::load(tolerances)?.parse(&registry.tolerances())?; Self::match_all(parsed, partial, inputs) } @@ -298,7 +297,7 @@ impl Raw { fn parse<'a>( self, - entries: &'a HashMap<&'static str, registry::RegisteredTolerance<'a>>, + entries: &HashMap<&'static str, registry::RegisteredTolerance<'a>>, ) -> anyhow::Result> { // Attempt to parse raw tolerances into registered tolerance inputs. let num_checks = self.checks.len(); @@ -356,7 +355,7 @@ impl Raw { .with_context(context)?; Ok(ParsedInner { - entry, + entry: entry.clone(), tolerance: Rc::new(tolerance), input: unprocessed.input, }) @@ -382,7 +381,7 @@ impl Raw { /// * The tag in `input` exists within at least one of the regressions in `entry`. #[derive(Debug)] struct ParsedInner<'a> { - entry: &'a registry::RegisteredTolerance<'a>, + entry: registry::RegisteredTolerance<'a>, tolerance: Rc, input: jobs::Unprocessed, } diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index c7a4d2108..e0374e7f8 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,7 +8,7 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{checker::Checker, input, registry, Any}; +use crate::{checker::Checker, input, Registry, Any}; #[derive(Debug)] pub(crate) struct Jobs { @@ -33,14 +33,14 @@ impl Jobs { /// the post-load validation of the requested runs, including: /// /// * Resolution of input files. - pub(crate) fn load(path: &Path, registry: &registry::Inputs) -> anyhow::Result { + pub(crate) fn load(path: &Path, registry: &Registry) -> anyhow::Result { Self::parse(&Partial::load(path)?, registry) } /// Parse `self` from a [`Partial`]. /// /// This method also perform deserialization checks on the parsed inputs. 
- pub(crate) fn parse(partial: &Partial, registry: &registry::Inputs) -> anyhow::Result { + pub(crate) fn parse(partial: &Partial, registry: &Registry) -> anyhow::Result { let mut checker = Checker::new( partial .search_directories @@ -65,7 +65,7 @@ impl Jobs { }; let input = registry - .get(&unprocessed.tag) + .input(&unprocessed.tag) .ok_or_else(|| { anyhow::anyhow!("Unrecognized input tag: \"{}\"", unprocessed.tag) }) diff --git a/diskann-benchmark-runner/src/lib.rs b/diskann-benchmark-runner/src/lib.rs index 9b8cf3cdb..a703d4534 100644 --- a/diskann-benchmark-runner/src/lib.rs +++ b/diskann-benchmark-runner/src/lib.rs @@ -27,6 +27,7 @@ pub use checker::{CheckDeserialization, Checker}; pub use input::Input; pub use output::Output; pub use result::Checkpoint; +pub use registry::Registry; #[cfg(any(test, feature = "test-app"))] pub mod test; diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 5d8c7366c..8627eb30e 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -13,61 +13,6 @@ use crate::{ input, Any, Checkpoint, Input, Output, }; -/// A collection of [`crate::Input`]. -pub struct Inputs { - // Inputs keyed by their tag type. - inputs: HashMap<&'static str, Box>, -} - -impl Inputs { - /// Construct a new empty [`Inputs`] registry. - pub fn new() -> Self { - Self { - inputs: HashMap::new(), - } - } - - /// Return the input with the registered `tag` if present. Otherwise, return `None`. - pub fn get(&self, tag: &str) -> Option> { - self.inputs.get(tag).map(|v| input::Registered(&**v)) - } - - /// Register the [`Input`] `T` in the registry. - /// - /// Returns an error if any other input with the same [`Input::tag()`] has been registered - /// while leaving the underlying registry unchanged. 
- pub fn register(&mut self) -> anyhow::Result<()> - where - T: Input + 'static, - { - let tag = T::tag(); - match self.inputs.entry(tag) { - Entry::Vacant(entry) => { - entry.insert(Box::new(crate::input::Wrapper::::new())); - Ok(()) - } - Entry::Occupied(_) => { - #[derive(Debug, Error)] - #[error("An input with the tag \"{}\" already exists", self.0)] - struct AlreadyExists(&'static str); - - Err(anyhow::anyhow!(AlreadyExists(tag))) - } - } - } - - /// Return an iterator over all registered input tags in an unspecified order. - pub fn tags(&self) -> impl ExactSizeIterator + use<'_> { - self.inputs.keys().copied() - } -} - -impl Default for Inputs { - fn default() -> Self { - Self::new() - } -} - /// A registered benchmark entry: a name paired with a type-erased benchmark. pub(crate) struct RegisteredBenchmark { name: String, @@ -95,23 +40,47 @@ impl RegisteredBenchmark { } /// A collection of registered benchmarks. -pub struct Benchmarks { +pub struct Registry { + // Inputs keyed by their tag type. + inputs: HashMap<&'static str, Box>, benchmarks: Vec, } -impl Benchmarks { +impl Registry { /// Return a new empty registry. pub fn new() -> Self { Self { + inputs: HashMap::new(), benchmarks: Vec::new(), } } + /// Return the input with the registered `tag` if present. Otherwise, return `None`. + /// + /// Inputs are automatically registered as a side-effect of: + /// + /// * [`register`](Self::register) + /// * [`register_regression`](Self::register_regression) + pub fn input(&self, tag: &str) -> Option> { + self.inputs.get(tag).map(|t| input::Registered(&**t)) + } + + /// Return an iterator over all registered input tags in an unspecified order. + pub fn input_tags(&self) -> impl ExactSizeIterator + use<'_> { + self.inputs.keys().copied() + } + /// Register a new benchmark with the given name. 
- pub fn register(&mut self, name: impl Into, benchmark: T) + pub fn register( + &mut self, + name: impl Into, + benchmark: T, + ) -> Result<(), RegistryError> where T: Benchmark, { + self.register_input::()?; + self.benchmarks.push(RegisteredBenchmark { name: name.into(), benchmark: Box::new(benchmark::internal::Wrapper::::new( @@ -119,6 +88,7 @@ impl Benchmarks { benchmark::internal::NoRegression, )), }); + Ok(()) } /// Return an iterator over registered benchmark names and their descriptions. @@ -207,6 +177,33 @@ impl Benchmarks { .map(|(entry, _)| entry) } + fn register_input(&mut self) -> Result<(), RegistryError> + where + T: Input + 'static, + { + let tag = T::tag(); + let wrapper = crate::input::Wrapper::::new(); + match self.inputs.entry(tag) { + Entry::Vacant(v) => { + v.insert(Box::new(wrapper)); + Ok(()) + } + Entry::Occupied(o) => { + use input::DynInput; + + if o.get().as_any().is::>() { + Ok(()) + } else { + Err(RegistryError { + tag, + existing: o.get().type_name(), + new: wrapper.type_name(), + }) + } + } + } + } + //-------------------// // Regression Checks // //-------------------// @@ -215,10 +212,16 @@ impl Benchmarks { /// /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark /// itself will be reachable via [`Check`](crate::app::Check). 
- pub fn register_regression(&mut self, name: impl Into, benchmark: T) + pub fn register_regression( + &mut self, + name: impl Into, + benchmark: T, + ) -> Result<(), RegistryError> where T: Regression, { + self.register_input::()?; + let registered = benchmark::internal::Wrapper::::new( benchmark, benchmark::internal::WithRegression, @@ -227,6 +230,8 @@ impl Benchmarks { name: name.into(), benchmark: Box::new(registered), }); + + Ok(()) } /// Return a collection of all tolerance related inputs, keyed by the input tag type @@ -261,12 +266,27 @@ impl Benchmarks { } } -impl Default for Benchmarks { +impl Default for Registry { fn default() -> Self { Self::new() } } +/// Error for [`Registry::register`] or [`Registry::register_regression`]. +#[derive(Debug, Error)] +#[error( + "A different input with tag \"{}\" was already registered. Existing type: \"{}\". New type: \"{}\"", + self.tag, + self.existing, + self.new, +)] +pub struct RegistryError { + tag: &'static str, + existing: &'static str, + new: &'static str, +} + + /// Document the reason for a method matching failure. pub struct Mismatch { method: String, @@ -319,7 +339,7 @@ impl RegressionBenchmark<'_> { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub(crate) struct RegisteredTolerance<'a> { /// The tolerance parser. 
pub(crate) tolerance: input::Registered<'a>, diff --git a/diskann-benchmark-runner/src/test/mod.rs b/diskann-benchmark-runner/src/test/mod.rs index ea9853e5e..0b23b7600 100644 --- a/diskann-benchmark-runner/src/test/mod.rs +++ b/diskann-benchmark-runner/src/test/mod.rs @@ -15,20 +15,17 @@ pub(crate) use typed::TypeInput; // API // ///////// -pub fn register_inputs(inputs: &mut registry::Inputs) -> anyhow::Result<()> { - inputs.register::()?; - inputs.register::()?; - Ok(()) -} - -pub fn register_benchmarks(benchmarks: &mut registry::Benchmarks) { - benchmarks.register_regression("type-bench-f32", typed::TypeBench::::new()); - benchmarks.register_regression("type-bench-i8", typed::TypeBench::::new()); - benchmarks.register_regression( +pub fn register_benchmarks( + registry: &mut registry::Registry, +) -> Result<(), registry::RegistryError> { + registry.register_regression("type-bench-f32", typed::TypeBench::::new())?; + registry.register_regression("type-bench-i8", typed::TypeBench::::new())?; + registry.register_regression( "exact-type-bench-f32-1000", typed::ExactTypeBench::::new(), - ); + )?; - benchmarks.register("simple-bench", dim::SimpleBench); - benchmarks.register_regression("dim-bench", dim::DimBench); + registry.register("simple-bench", dim::SimpleBench)?; + registry.register_regression("dim-bench", dim::DimBench)?; + Ok(()) } diff --git a/diskann-benchmark-simd/src/bin.rs b/diskann-benchmark-simd/src/bin.rs index 3bdae01af..6ed64056d 100644 --- a/diskann-benchmark-simd/src/bin.rs +++ b/diskann-benchmark-simd/src/bin.rs @@ -13,15 +13,11 @@ pub fn main() -> anyhow::Result<()> { } fn main_inner(app: &App, output: &mut dyn Output) -> anyhow::Result<()> { - // Register inputs and benchmarks. - let mut inputs = registry::Inputs::new(); - inputs.register::()?; - - let mut benchmarks = registry::Benchmarks::new(); - register(&mut benchmarks); + let mut registry = registry::Registry::new(); + register(&mut registry)?; // Here we go! 
- app.run(&inputs, &benchmarks, output) + app.run(&registry, output) } /////////// diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index d6d0f86bb..830731df1 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -34,8 +34,8 @@ use diskann_benchmark_runner::{ // Public API // //////////////// -pub fn register(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) { - register_benchmarks_impl(dispatcher) +pub fn register(registry: &mut diskann_benchmark_runner::registry::Registry) -> anyhow::Result<()> { + Ok(register_benchmarks_impl(registry)?) } /////////// @@ -302,105 +302,108 @@ impl std::fmt::Display for CheckResult { // Benchmark Registration // //////////////////////////// -fn register_benchmarks_impl(dispatcher: &mut diskann_benchmark_runner::registry::Benchmarks) { +fn register_benchmarks_impl( + registry: &mut diskann_benchmark_runner::registry::Registry, +) -> Result<(), diskann_benchmark_runner::registry::RegistryError> { // x86-64-v4 #[cfg(target_arch = "x86_64")] { - dispatcher.register_regression( + registry.register_regression( "simd-op-f32xf32-x86_64_V4", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-f16xf16-x86_64_V4", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-u8xu8-x86_64_V4", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-i8xi8-x86_64_V4", Kernel::::new(), - ); + )?; } // x86-64-v3 #[cfg(target_arch = "x86_64")] { - dispatcher.register_regression( + registry.register_regression( "simd-op-f32xf32-x86_64_V3", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-f16xf16-x86_64_V3", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-u8xu8-x86_64_V3", Kernel::::new(), - ); - 
dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-i8xi8-x86_64_V3", Kernel::::new(), - ); + )?; } // aarch64-neon #[cfg(target_arch = "aarch64")] { - dispatcher.register_regression( + registry.register_regression( "simd-op-f32xf32-aarch64_neon", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-f16xf16-aarch64_neon", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-u8xu8-aarch64_neon", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-i8xi8-aarch64_neon", Kernel::::new(), - ); + )?; } // scalar - dispatcher.register_regression( + registry.register_regression( "simd-op-f32xf32-scalar", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-f16xf16-scalar", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-u8xu8-scalar", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-i8xi8-scalar", Kernel::::new(), - ); + )?; // reference - dispatcher.register_regression( + registry.register_regression( "simd-op-f32xf32-reference", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-f16xf16-reference", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-u8xu8-reference", Kernel::::new(), - ); - dispatcher.register_regression( + )?; + registry.register_regression( "simd-op-i8xi8-reference", Kernel::::new(), - ); + )?; + Ok(()) } ////////////// diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index 6c5298dd8..1b0a381c7 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -121,11 +121,14 @@ 
where // Benchmark Registration // //////////////////////////// -pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { - benchmarks.register_regression("disk-index-f32", DiskIndex::::new()); - benchmarks.register_regression("disk-index-f16", DiskIndex::::new()); - benchmarks.register_regression("disk-index-u8", DiskIndex::::new()); - benchmarks.register_regression("disk-index-i8", DiskIndex::::new()); +pub(super) fn register_benchmarks( + registry: &mut diskann_benchmark_runner::registry::Registry, +) -> anyhow::Result<()> { + registry.register_regression("disk-index-f32", DiskIndex::::new())?; + registry.register_regression("disk-index-f16", DiskIndex::::new())?; + registry.register_regression("disk-index-u8", DiskIndex::::new())?; + registry.register_regression("disk-index-i8", DiskIndex::::new())?; + Ok(()) } ///////////////////////// diff --git a/diskann-benchmark/src/backend/disk_index/mod.rs b/diskann-benchmark/src/backend/disk_index/mod.rs index 2fcd66136..dac2b15b1 100644 --- a/diskann-benchmark/src/backend/disk_index/mod.rs +++ b/diskann-benchmark/src/backend/disk_index/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; cfg_if::cfg_if! { if #[cfg(feature = "disk-index")] { @@ -13,8 +13,8 @@ cfg_if::cfg_if! { mod json_spancollector; /// Register disk index benchmarks when the `disk-index` feature is enabled. - pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { - benchmarks::register_benchmarks(registry); + pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + benchmarks::register_benchmarks(registry) } } else { crate::utils::stub_impl!( @@ -23,8 +23,8 @@ cfg_if::cfg_if! { ); /// Register a stub that guides users to enable the `disk-index` feature. 
- pub(crate) fn register_benchmarks(registry: &mut Benchmarks) { - imp::register("disk-index", registry); + pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + imp::register("disk-index", registry) } } } diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index 3516ab568..ecd1f1eb7 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; const NAME: &str = "minmax-exhaustive-search"; @@ -11,17 +11,19 @@ crate::utils::stub_impl!("minmax-quantization", inputs::exhaustive::MinMax); // MinMax - requires feature "minmax-quantization" #[cfg(feature = "minmax-quantization")] -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register(NAME, imp::MinMaxQ::<1>); - benchmarks.register(NAME, imp::MinMaxQ::<2>); - benchmarks.register(NAME, imp::MinMaxQ::<4>); - benchmarks.register(NAME, imp::MinMaxQ::<8>); +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + registry.register(NAME, imp::MinMaxQ::<1>)?; + registry.register(NAME, imp::MinMaxQ::<2>)?; + registry.register(NAME, imp::MinMaxQ::<4>)?; + registry.register(NAME, imp::MinMaxQ::<8>)?; + + Ok(()) } // Stub implementation #[cfg(not(feature = "minmax-quantization"))] -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - imp::register(NAME, benchmarks) +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + imp::register(NAME, registry) } ///////////// diff --git a/diskann-benchmark/src/backend/exhaustive/mod.rs b/diskann-benchmark/src/backend/exhaustive/mod.rs index c756c3451..0237262ff 100644 --- a/diskann-benchmark/src/backend/exhaustive/mod.rs +++ b/diskann-benchmark/src/backend/exhaustive/mod.rs @@ -14,10 
+14,11 @@ mod minmax; mod product; mod spherical; -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; -pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - spherical::register_benchmarks(benchmarks); - minmax::register_benchmarks(benchmarks); - product::register_benchmarks(benchmarks); +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + spherical::register_benchmarks(registry)?; + minmax::register_benchmarks(registry)?; + product::register_benchmarks(registry)?; + Ok(()) } diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 78711626e..5c41f2928 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -3,18 +3,20 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; const NAME: &str = "product-exhaustive-search"; crate::utils::stub_impl!("product-quantization", inputs::exhaustive::Product); -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { #[cfg(feature = "product-quantization")] - benchmarks.register(NAME, imp::ProductQ); + registry.register(NAME, imp::ProductQ)?; #[cfg(not(feature = "product-quantization"))] - imp::register(NAME, benchmarks) + imp::register(NAME, registry)?; + + Ok(()) } ////////////// diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index b7dfd69b0..08fe66576 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; const NAME: &str = "spherical-exhaustive-search"; @@ -11,17 +11,19 @@ crate::utils::stub_impl!("spherical-quantization", inputs::exhaustive::Spherical // Spherical - requires feature "spherical-quantization" #[cfg(feature = "spherical-quantization")] -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register(NAME, imp::SphericalQ::<1>); - benchmarks.register(NAME, imp::SphericalQ::<2>); - benchmarks.register(NAME, imp::SphericalQ::<4>); - benchmarks.register(NAME, imp::SphericalQ::<8>); +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + registry.register(NAME, imp::SphericalQ::<1>)?; + registry.register(NAME, imp::SphericalQ::<2>)?; + registry.register(NAME, imp::SphericalQ::<4>)?; + registry.register(NAME, imp::SphericalQ::<8>)?; + + Ok(()) } // Stub implementation #[cfg(not(feature = "spherical-quantization"))] -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { - imp::register(NAME, benchmarks) +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { + imp::register(NAME, registry) } //////////////// diff --git a/diskann-benchmark/src/backend/filters/benchmark.rs b/diskann-benchmark/src/backend/filters/benchmark.rs index 7ce92420b..c5c4f3174 100644 --- a/diskann-benchmark/src/backend/filters/benchmark.rs +++ b/diskann-benchmark/src/backend/filters/benchmark.rs @@ -7,7 +7,7 @@ use anyhow::Result; use diskann_benchmark_runner::{ dispatcher::{FailureScore, MatchScore}, output::Output, - registry::Benchmarks, + registry::Registry, utils::{percentiles, MicroSeconds}, Benchmark, Checkpoint, }; @@ -28,8 +28,8 @@ use crate::{ utils::filters::QueryBitmapEvaluator, }; -pub(crate) fn register_benchmarks(benchmarks: &mut Benchmarks) { - benchmarks.register("metadata-index-build", MetadataIndexJob); +pub(crate) fn register_benchmarks(benchmarks: &mut Registry) -> 
anyhow::Result<()> { + Ok(benchmarks.register("metadata-index-build", MetadataIndexJob)?) } // Metadata-only index job. diff --git a/diskann-benchmark/src/backend/filters/mod.rs b/diskann-benchmark/src/backend/filters/mod.rs index ba8057672..127b7fcfe 100644 --- a/diskann-benchmark/src/backend/filters/mod.rs +++ b/diskann-benchmark/src/backend/filters/mod.rs @@ -6,6 +6,8 @@ mod benchmark; // Public registration function -pub(crate) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { - benchmark::register_benchmarks(benchmarks); +pub(crate) fn register_benchmarks( + registry: &mut diskann_benchmark_runner::registry::Registry, +) -> anyhow::Result<()> { + benchmark::register_benchmarks(registry) } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 57aafc8eb..997d007b1 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -57,7 +57,9 @@ use crate::{ // Benchmark Registration // //////////////////////////// -pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { +pub(super) fn register_benchmarks( + registry: &mut diskann_benchmark_runner::registry::Registry, +) -> anyhow::Result<()> { // Notes on registration: // // We register all supported search types for `f32`, but intentionally limit the number @@ -70,49 +72,50 @@ pub(super) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::reg // care. 
// Full Precision - benchmarks.register( + registry.register( "graph-index-full-precision-f32", FullPrecision::::new() .search(plugins::Topk) .search(plugins::Range) .search(plugins::TopkBetaFilter) .search(plugins::TopkMultihopFilter), - ); + )?; - benchmarks.register( + registry.register( "graph-index-full-precision-f16", FullPrecision::::new().search(plugins::Topk), - ); - benchmarks.register( + )?; + registry.register( "graph-index-full-precision-u8", FullPrecision::::new().search(plugins::Topk), - ); - benchmarks.register( + )?; + registry.register( "graph-index-full-precision-i8", FullPrecision::::new().search(plugins::Topk), - ); + )?; // Dynamic Full Precision - benchmarks.register( + registry.register( "graph-index-dynamic-full-precision-f32", DynamicFullPrecision::::new(), - ); - benchmarks.register( + )?; + registry.register( "graph-index-dynamic-full-precision-f16", DynamicFullPrecision::::new(), - ); - benchmarks.register( + )?; + registry.register( "graph-index-dynamic-full-precision-u8", DynamicFullPrecision::::new(), - ); - benchmarks.register( + )?; + registry.register( "graph-index-dynamic-full-precision-i8", DynamicFullPrecision::::new(), - ); + )?; - product::register_benchmarks(benchmarks); - scalar::register_benchmarks(benchmarks); - spherical::register_benchmarks(benchmarks); + product::register_benchmarks(registry)?; + scalar::register_benchmarks(registry)?; + spherical::register_benchmarks(registry)?; + Ok(()) } type FullPrecisionProvider = inmem::DefaultProvider< diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index 269887c6d..b8261babb 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -15,6 +15,8 @@ mod product; mod scalar; mod spherical; -pub(crate) fn register_benchmarks(benchmarks: &mut diskann_benchmark_runner::registry::Benchmarks) { - benchmarks::register_benchmarks(benchmarks) +pub(crate) fn register_benchmarks( + 
registry: &mut diskann_benchmark_runner::registry::Registry, +) -> anyhow::Result<()> { + benchmarks::register_benchmarks(registry) } diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index f393ffd83..28fd8380e 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; // Create a stub-module if the "spherical-quantization" feature is disabled. crate::utils::stub_impl!( @@ -11,7 +11,7 @@ crate::utils::stub_impl!( inputs::graph_index::IndexPQOperation ); -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { #[cfg(feature = "product-quantization")] { use crate::backend::index::search::plugins; @@ -21,21 +21,23 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { // // Feel free to add search plugins, but be mindful of the monomorphization cost. - benchmarks.register( + registry.register( "graph-index-pq-f32", imp::ProductQuantized::::new() .search(plugins::Topk) .search(plugins::Range), - ); - benchmarks.register( + )?; + registry.register( "graph-index-pq-f16", imp::ProductQuantized::::new().search(plugins::Topk), - ); + )?; } // Stub implementation #[cfg(not(feature = "product-quantization"))] - imp::register("graph-index-pq", benchmarks); + imp::register("graph-index-pq", registry)?; + + Ok(()) } #[cfg(feature = "product-quantization")] diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index 79ce7e2de..2b8403474 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -3,18 +3,18 @@ * Licensed under the MIT license. 
*/ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; // Create a stub-module if the "scalar-quantization" feature is disabled. crate::utils::stub_impl!("scalar-quantization", inputs::graph_index::IndexSQOperation); -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { +pub(super) fn register_benchmarks(benchmarks: &mut Registry) -> anyhow::Result<()> { #[cfg(feature = "scalar-quantization")] { use crate::backend::index::search::plugins::Topk; // NOTE: This benchmark is heavily monomorphized. Each `(NBITS, T)` pair - // generates a full `Benchmark` impl via the `impl_sq_build!` macro in `mod imp`, + // generates a full `Registry` impl via the `impl_sq_build!` macro in `mod imp`, // which materially impacts compile time. We intentionally keep the registered // set minimal (`f32` at 1, 4, and 8 bits) to cover the common cases used by // `example/scalar.json`. @@ -32,20 +32,22 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { benchmarks.register( "graph-index-sq-8-bit-f32", imp::ScalarQuantized::<8, f32>::new().search(Topk), - ); + )?; benchmarks.register( "graph-index-sq-4-bit-f32", imp::ScalarQuantized::<4, f32>::new().search(Topk), - ); + )?; benchmarks.register( "graph-index-sq-1-bit-f32", imp::ScalarQuantized::<1, f32>::new().search(Topk), - ); + )?; } // Stub implementation #[cfg(not(feature = "scalar-quantization"))] - imp::register("graph-index-sq", benchmarks); + imp::register("graph-index-sq", benchmarks)?; + + Ok(()) } #[cfg(feature = "scalar-quantization")] diff --git a/diskann-benchmark/src/backend/index/search/plugins.rs b/diskann-benchmark/src/backend/index/search/plugins.rs index 43b8ba3e8..e0aca593f 100644 --- a/diskann-benchmark/src/backend/index/search/plugins.rs +++ b/diskann-benchmark/src/backend/index/search/plugins.rs @@ -20,7 +20,7 @@ //! //! Benchmarks own a [`Plugins`] collection and register only the plugin types they want to //! support. 
The helper methods on [`Plugins`] then integrate with -//! [`diskann_benchmark_runner::Benchmarks`]: +//! [`diskann_benchmark_runner::Registry`]: //! //! * [`Plugins::format_kinds`]: format the registered plugin labels for diagnostics. //! * [`Plugins::is_match`]: check whether any registered plugin accepts a requested `Kind`. diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index 20e9c0e29..b1bd9dd59 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Benchmarks; +use diskann_benchmark_runner::registry::Registry; // Create a stub-module if the "spherical-quantization" feature is disabled. crate::utils::stub_impl!( @@ -11,7 +11,7 @@ crate::utils::stub_impl!( inputs::graph_index::SphericalQuantBuild ); -pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { const NAME: &str = "graph-index-spherical-quantization"; #[cfg(feature = "spherical-quantization")] @@ -21,37 +21,39 @@ pub(super) fn register_benchmarks(benchmarks: &mut Benchmarks) { // NOTE: Since the spherical provider is not generic on the number of bits, the // implementations of the search-plugins are shared by all bit-widths. Registering // all plugins for all bit widths does not meaningfully increase compilation time. 
- benchmarks.register( + registry.register( NAME, imp::SphericalQ::<1>::new() .search(plugins::Topk) .search(plugins::Range) .search(plugins::TopkBetaFilter) .search(plugins::TopkMultihopFilter), - ); + )?; - benchmarks.register( + registry.register( NAME, imp::SphericalQ::<2>::new() .search(plugins::Topk) .search(plugins::Range) .search(plugins::TopkBetaFilter) .search(plugins::TopkMultihopFilter), - ); + )?; - benchmarks.register( + registry.register( NAME, imp::SphericalQ::<4>::new() .search(plugins::Topk) .search(plugins::Range) .search(plugins::TopkBetaFilter) .search(plugins::TopkMultihopFilter), - ); + )?; } // Stub implementation #[cfg(not(feature = "spherical-quantization"))] - imp::register(NAME, benchmarks) + imp::register(NAME, registry)?; + + Ok(()) } //////////////// diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 24fe91d7e..7552e1c20 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -8,9 +8,12 @@ mod exhaustive; mod filters; mod index; -pub(crate) fn register_benchmarks(registry: &mut diskann_benchmark_runner::registry::Benchmarks) { - exhaustive::register_benchmarks(registry); - disk_index::register_benchmarks(registry); - index::register_benchmarks(registry); - filters::register_benchmarks(registry); +pub(crate) fn register_benchmarks( + registry: &mut diskann_benchmark_runner::registry::Registry, +) -> anyhow::Result<()> { + exhaustive::register_benchmarks(registry)?; + disk_index::register_benchmarks(registry)?; + index::register_benchmarks(registry)?; + filters::register_benchmarks(registry)?; + Ok(()) } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 2951d1fe4..00f6067d4 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -25,13 +25,6 @@ use crate::{ as_input!(DiskIndexOperation); -pub(super) fn register_inputs( - registry: &mut 
diskann_benchmark_runner::registry::Inputs, -) -> anyhow::Result<()> { - registry.register::()?; - Ok(()) -} - /////////// // Input // /////////// diff --git a/diskann-benchmark/src/inputs/exhaustive.rs b/diskann-benchmark/src/inputs/exhaustive.rs index 4581956cb..d73bc1491 100644 --- a/diskann-benchmark/src/inputs/exhaustive.rs +++ b/diskann-benchmark/src/inputs/exhaustive.rs @@ -31,15 +31,6 @@ as_input!(Spherical); as_input!(Product); as_input!(MinMax); -pub(super) fn register_inputs( - registry: &mut diskann_benchmark_runner::registry::Inputs, -) -> anyhow::Result<()> { - registry.register::()?; - registry.register::()?; - registry.register::()?; - Ok(()) -} - //////////// // Search // //////////// diff --git a/diskann-benchmark/src/inputs/filters.rs b/diskann-benchmark/src/inputs/filters.rs index 981cef3ce..09fdaf919 100644 --- a/diskann-benchmark/src/inputs/filters.rs +++ b/diskann-benchmark/src/inputs/filters.rs @@ -14,13 +14,6 @@ use crate::inputs::{as_input, Example}; as_input!(MetadataIndexBuild); -pub(super) fn register_inputs( - registry: &mut diskann_benchmark_runner::registry::Inputs, -) -> anyhow::Result<()> { - registry.register::()?; - Ok(()) -} - /////////////////////////////// // Metadata-only Index Build // /////////////////////////////// diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 849b1a381..95cb89484 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -39,17 +39,6 @@ as_input!(IndexSQOperation); as_input!(SphericalQuantBuild); as_input!(DynamicIndexRun); -pub(super) fn register_inputs( - registry: &mut diskann_benchmark_runner::registry::Inputs, -) -> anyhow::Result<()> { - registry.register::()?; - registry.register::()?; - registry.register::()?; - registry.register::()?; - registry.register::()?; - Ok(()) -} - //////////// // Search // //////////// diff --git a/diskann-benchmark/src/inputs/mod.rs 
b/diskann-benchmark/src/inputs/mod.rs index 856412e2a..7875beb1d 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -9,16 +9,6 @@ pub(crate) mod filters; pub(crate) mod graph_index; pub(crate) mod save_and_load; -pub(crate) fn register_inputs( - registry: &mut diskann_benchmark_runner::registry::Inputs, -) -> anyhow::Result<()> { - graph_index::register_inputs(registry)?; - exhaustive::register_inputs(registry)?; - disk::register_inputs(registry)?; - filters::register_inputs(registry)?; - Ok(()) -} - /// Construct an example input of type `Self`. pub(crate) trait Example { fn example() -> Self; diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index a35d85427..2235b0dfa 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -42,15 +42,11 @@ impl Cli { fn run(&self, output: &mut dyn runner::Output) -> anyhow::Result<()> { self.check_target(output)?; - // Collect inputs. - let mut inputs = runner::registry::Inputs::new(); - inputs::register_inputs(&mut inputs)?; - // Collect benchmarks. - let mut benchmarks = runner::registry::Benchmarks::new(); - backend::register_benchmarks(&mut benchmarks); + let mut registry = runner::registry::Registry::new(); + backend::register_benchmarks(&mut registry)?; - self.app.run(&inputs, &benchmarks, output) + self.app.run(®istry, output) } #[cfg(test)] diff --git a/diskann-benchmark/src/utils/mod.rs b/diskann-benchmark/src/utils/mod.rs index 97eeae777..e4bf5cae7 100644 --- a/diskann-benchmark/src/utils/mod.rs +++ b/diskann-benchmark/src/utils/mod.rs @@ -103,14 +103,14 @@ macro_rules! 
stub_impl { use diskann_benchmark_runner::{ dispatcher::{FailureScore, MatchScore}, output::Output, - registry::Benchmarks, + registry::Registry, Benchmark, Checkpoint, }; use crate::inputs; - pub(super) fn register(name: &str, registry: &mut Benchmarks) { - registry.register(name, Stub); + pub(super) fn register(name: &str, registry: &mut Registry) -> anyhow::Result<()> { + Ok(registry.register(name, Stub)?) } /// An empty placeholder to provide a hint for the necessary feature. From 41bb2a943188aee27c284b0e25659e7e170051a9 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 12 May 2026 14:02:34 -0700 Subject: [PATCH 2/9] Cleanup. --- diskann-benchmark-runner/src/any.rs | 12 +-- diskann-benchmark-runner/src/app.rs | 2 +- diskann-benchmark-runner/src/jobs.rs | 2 +- diskann-benchmark-runner/src/lib.rs | 2 +- diskann-benchmark-runner/src/registry.rs | 91 ++++++++++++++++++- diskann-benchmark-simd/src/bin.rs | 6 +- diskann-benchmark-simd/src/lib.rs | 4 +- .../src/backend/disk_index/benchmarks.rs | 6 +- .../src/backend/disk_index/mod.rs | 2 +- .../src/backend/exhaustive/mod.rs | 2 +- .../src/backend/exhaustive/product.rs | 2 +- .../src/backend/exhaustive/spherical.rs | 2 +- .../src/backend/filters/benchmark.rs | 3 +- diskann-benchmark/src/backend/filters/mod.rs | 6 +- .../src/backend/index/benchmarks.rs | 6 +- diskann-benchmark/src/backend/index/mod.rs | 6 +- .../src/backend/index/product.rs | 2 +- diskann-benchmark/src/backend/index/scalar.rs | 2 +- .../src/backend/index/spherical.rs | 2 +- diskann-benchmark/src/backend/mod.rs | 6 +- diskann-benchmark/src/main.rs | 2 +- diskann-benchmark/src/utils/mod.rs | 3 +- 22 files changed, 119 insertions(+), 52 deletions(-) diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs index 683614f6b..b06cfdad3 100644 --- a/diskann-benchmark-runner/src/any.rs +++ b/diskann-benchmark-runner/src/any.rs @@ -3,10 +3,7 @@ * Licensed under the MIT license. 
*/ -use crate::{ - dispatcher::{DispatchRule, FailureScore, MatchScore}, - Input, -}; +use crate::dispatcher::{DispatchRule, FailureScore, MatchScore}; /// An refinement of [`std::any::Any`] with an associated name (tag) and serialization. /// @@ -36,13 +33,6 @@ impl Any { } } - pub fn input(any: T) -> Self - where - T: Input + serde::Serialize + std::fmt::Debug + 'static, - { - Self::new(any, T::tag()) - } - /// A lower level API for constructing an [`Any`] that decouples the serialized /// representation from the inmemory representation. /// diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index 249907858..ac40cd2d5 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -213,7 +213,7 @@ impl App { } writeln!(output, "Available input kinds are listed below:")?; - let mut tags: Vec<_> = registry.tags().collect(); + let mut tags: Vec<_> = registry.input_tags().collect(); tags.sort(); for i in tags.iter() { writeln!(output, " {}", i)?; diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index e0374e7f8..e7ca99a2a 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,7 +8,7 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{checker::Checker, input, Registry, Any}; +use crate::{checker::Checker, input, Any, Registry}; #[derive(Debug)] pub(crate) struct Jobs { diff --git a/diskann-benchmark-runner/src/lib.rs b/diskann-benchmark-runner/src/lib.rs index a703d4534..4c2f3cd42 100644 --- a/diskann-benchmark-runner/src/lib.rs +++ b/diskann-benchmark-runner/src/lib.rs @@ -26,8 +26,8 @@ pub use benchmark::Benchmark; pub use checker::{CheckDeserialization, Checker}; pub use input::Input; pub use output::Output; +pub use registry::{Registry, RegistryError}; pub use result::Checkpoint; -pub use registry::Registry; #[cfg(any(test, feature = "test-app"))] pub mod test; 
diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 8627eb30e..1cbd6a289 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -62,7 +62,7 @@ impl Registry { /// * [`register`](Self::register) /// * [`register_regression`](Self::register_regression) pub fn input(&self, tag: &str) -> Option> { - self.inputs.get(tag).map(|t| input::Registered(&**t)) + self._input(tag).map(input::Registered) } /// Return an iterator over all registered input tags in an unspecified order. @@ -70,7 +70,11 @@ impl Registry { self.inputs.keys().copied() } - /// Register a new benchmark with the given name. + /// Register a new `benchmark` with the given `name`. + /// + /// As a side-effect, the benchmark's [`Input`](Benchmark::Input) type is also registered. + /// Duplicate registrations of the same tag and type are allowed; mismatched types for the + /// same tag return an error. pub fn register( &mut self, name: impl Into, @@ -177,6 +181,10 @@ impl Registry { .map(|(entry, _)| entry) } + fn _input(&self, tag: &str) -> Option<&dyn input::DynInput> { + self.inputs.get(tag).map(|v| &**v) + } + fn register_input(&mut self) -> Result<(), RegistryError> where T: Input + 'static, @@ -208,7 +216,11 @@ impl Registry { // Regression Checks // //-------------------// - /// Register a regression-checkable benchmark with the associated name. + /// Register a regression-checkable `benchmark` with the given `name`. + /// + /// As a side-effect, the benchmark's [`Input`](Benchmark::Input) type is also registered. + /// Duplicate registrations of the same tag and type are allowed; mismatched types for the + /// same tag return an error. /// /// Upon registration, the associated [`Regression::Tolerances`] input and the benchmark /// itself will be reachable via [`Check`](crate::app::Check). 
@@ -286,7 +298,6 @@ pub struct RegistryError { new: &'static str, } - /// Document the reason for a method matching failure. pub struct Mismatch { method: String, @@ -363,3 +374,75 @@ impl std::fmt::Debug for Capture<'_> { self.0.description(f, self.1) } } + +/////////// +// Tests // +/////////// + +#[cfg(test)] +mod tests { + use super::*; + + use crate::{input, Checker}; + + macro_rules! input { + ($T:ident, $tag:literal) => { + struct $T; + + impl Input for $T { + fn tag() -> &'static str { + $tag + } + fn try_deserialize( + _serialized: &serde_json::Value, + _checker: &mut Checker, + ) -> anyhow::Result { + unimplemented!("this struct is for test only"); + } + fn example() -> anyhow::Result { + unimplemented!("this struct is for test only"); + } + } + }; + } + + // For the types below, `A` and `B` have distinct tags, but `A2`'s tag conflicts with `A2`. + input!(A, "type-a"); + input!(B, "type-b"); + input!(A2, "type-a"); + + #[test] + fn test_name_conflicts() { + let mut registry = Registry::new(); + registry.register_input::().unwrap(); + registry.register_input::().unwrap(); + + let mut tags: Vec<_> = registry.input_tags().collect(); + tags.sort(); + assert_eq!(tags.as_slice(), ["type-a", "type-b"]); + + { + let a = registry._input(A::tag()).unwrap(); + assert!(a.as_any().is::>()); + + let name = a.type_name(); + assert!(name.contains("A"), "{}", name); + } + + { + let b = registry._input(B::tag()).unwrap(); + assert!(b.as_any().is::>()); + + let name = b.type_name(); + assert!(name.contains("B"), "{}", name); + } + + let err = registry.register_input::().unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("A different input with tag \"type-a\" was already registered"), + "FAILED: {}", + msg + ); + } +} diff --git a/diskann-benchmark-simd/src/bin.rs b/diskann-benchmark-simd/src/bin.rs index 6ed64056d..217215912 100644 --- a/diskann-benchmark-simd/src/bin.rs +++ b/diskann-benchmark-simd/src/bin.rs @@ -3,8 +3,8 @@ * Licensed under the MIT 
license. */ -use diskann_benchmark_runner::{output, registry, App, Output}; -use diskann_benchmark_simd::{register, SimdOp}; +use diskann_benchmark_runner::{output, App, Output, Registry}; +use diskann_benchmark_simd::register; pub fn main() -> anyhow::Result<()> { // Create the pocket bench application. @@ -13,7 +13,7 @@ pub fn main() -> anyhow::Result<()> { } fn main_inner(app: &App, output: &mut dyn Output) -> anyhow::Result<()> { - let mut registry = registry::Registry::new(); + let mut registry = Registry::new(); register(&mut registry)?; // Here we go! diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 830731df1..13bfb07b6 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -27,14 +27,14 @@ use diskann_benchmark_runner::{ num::{relative_change, NonNegativeFinite}, percentiles, MicroSeconds, }, - Any, Benchmark, CheckDeserialization, Checker, Input, + Any, Benchmark, CheckDeserialization, Checker, Input, Registry, }; //////////////// // Public API // //////////////// -pub fn register(registry: &mut diskann_benchmark_runner::registry::Registry) -> anyhow::Result<()> { +pub fn register(registry: &mut Registry) -> anyhow::Result<()> { Ok(register_benchmarks_impl(registry)?) 
} diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index 1b0a381c7..179fa479b 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -16,7 +16,7 @@ use diskann_benchmark_runner::{ fmt::Table, num::{relative_change, NonNegativeFinite}, }, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, + Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Registry, }; use diskann_providers::storage::FileStorageProvider; use half::f16; @@ -121,9 +121,7 @@ where // Benchmark Registration // //////////////////////////// -pub(super) fn register_benchmarks( - registry: &mut diskann_benchmark_runner::registry::Registry, -) -> anyhow::Result<()> { +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { registry.register_regression("disk-index-f32", DiskIndex::::new())?; registry.register_regression("disk-index-f16", DiskIndex::::new())?; registry.register_regression("disk-index-u8", DiskIndex::::new())?; diff --git a/diskann-benchmark/src/backend/disk_index/mod.rs b/diskann-benchmark/src/backend/disk_index/mod.rs index dac2b15b1..40177544b 100644 --- a/diskann-benchmark/src/backend/disk_index/mod.rs +++ b/diskann-benchmark/src/backend/disk_index/mod.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; cfg_if::cfg_if! 
{ if #[cfg(feature = "disk-index")] { diff --git a/diskann-benchmark/src/backend/exhaustive/mod.rs b/diskann-benchmark/src/backend/exhaustive/mod.rs index 0237262ff..41b4bc756 100644 --- a/diskann-benchmark/src/backend/exhaustive/mod.rs +++ b/diskann-benchmark/src/backend/exhaustive/mod.rs @@ -14,7 +14,7 @@ mod minmax; mod product; mod spherical; -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { spherical::register_benchmarks(registry)?; diff --git a/diskann-benchmark/src/backend/exhaustive/product.rs b/diskann-benchmark/src/backend/exhaustive/product.rs index 5c41f2928..fc2117c07 100644 --- a/diskann-benchmark/src/backend/exhaustive/product.rs +++ b/diskann-benchmark/src/backend/exhaustive/product.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; const NAME: &str = "product-exhaustive-search"; diff --git a/diskann-benchmark/src/backend/exhaustive/spherical.rs b/diskann-benchmark/src/backend/exhaustive/spherical.rs index 08fe66576..184721225 100644 --- a/diskann-benchmark/src/backend/exhaustive/spherical.rs +++ b/diskann-benchmark/src/backend/exhaustive/spherical.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; const NAME: &str = "spherical-exhaustive-search"; diff --git a/diskann-benchmark/src/backend/filters/benchmark.rs b/diskann-benchmark/src/backend/filters/benchmark.rs index c5c4f3174..6ed9d9ba2 100644 --- a/diskann-benchmark/src/backend/filters/benchmark.rs +++ b/diskann-benchmark/src/backend/filters/benchmark.rs @@ -7,9 +7,8 @@ use anyhow::Result; use diskann_benchmark_runner::{ dispatcher::{FailureScore, MatchScore}, output::Output, - registry::Registry, utils::{percentiles, MicroSeconds}, - Benchmark, Checkpoint, + Benchmark, Checkpoint, Registry, }; use diskann_label_filter::{ kv_index::GenericIndex, diff --git a/diskann-benchmark/src/backend/filters/mod.rs b/diskann-benchmark/src/backend/filters/mod.rs index 127b7fcfe..d4ed433bd 100644 --- a/diskann-benchmark/src/backend/filters/mod.rs +++ b/diskann-benchmark/src/backend/filters/mod.rs @@ -3,11 +3,11 @@ * Licensed under the MIT license. */ +use diskann_benchmark_runner::Registry; + mod benchmark; // Public registration function -pub(crate) fn register_benchmarks( - registry: &mut diskann_benchmark_runner::registry::Registry, -) -> anyhow::Result<()> { +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { benchmark::register_benchmarks(registry) } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 997d007b1..bb1531df6 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -19,7 +19,7 @@ use diskann_benchmark_runner::{ dispatcher::{DispatchRule, FailureScore, MatchScore}, output::Output, utils::datatype, - Benchmark, Checkpoint, + Benchmark, Checkpoint, Registry, }; use diskann_providers::{ index::diskann_async, @@ -57,9 +57,7 @@ use crate::{ // Benchmark Registration // //////////////////////////// -pub(super) fn register_benchmarks( - registry: &mut 
diskann_benchmark_runner::registry::Registry, -) -> anyhow::Result<()> { +pub(super) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { // Notes on registration: // // We register all supported search types for `f32`, but intentionally limit the number diff --git a/diskann-benchmark/src/backend/index/mod.rs b/diskann-benchmark/src/backend/index/mod.rs index b8261babb..d459fc489 100644 --- a/diskann-benchmark/src/backend/index/mod.rs +++ b/diskann-benchmark/src/backend/index/mod.rs @@ -3,6 +3,8 @@ * Licensed under the MIT license. */ +use diskann_benchmark_runner::Registry; + mod build; mod search; mod streaming; @@ -15,8 +17,6 @@ mod product; mod scalar; mod spherical; -pub(crate) fn register_benchmarks( - registry: &mut diskann_benchmark_runner::registry::Registry, -) -> anyhow::Result<()> { +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { benchmarks::register_benchmarks(registry) } diff --git a/diskann-benchmark/src/backend/index/product.rs b/diskann-benchmark/src/backend/index/product.rs index 28fd8380e..3d6f01207 100644 --- a/diskann-benchmark/src/backend/index/product.rs +++ b/diskann-benchmark/src/backend/index/product.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; // Create a stub-module if the "spherical-quantization" feature is disabled. crate::utils::stub_impl!( diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index 2b8403474..ad9fd9738 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; // Create a stub-module if the "scalar-quantization" feature is disabled. 
crate::utils::stub_impl!("scalar-quantization", inputs::graph_index::IndexSQOperation); diff --git a/diskann-benchmark/src/backend/index/spherical.rs b/diskann-benchmark/src/backend/index/spherical.rs index b1bd9dd59..71986020b 100644 --- a/diskann-benchmark/src/backend/index/spherical.rs +++ b/diskann-benchmark/src/backend/index/spherical.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; // Create a stub-module if the "spherical-quantization" feature is disabled. crate::utils::stub_impl!( diff --git a/diskann-benchmark/src/backend/mod.rs b/diskann-benchmark/src/backend/mod.rs index 7552e1c20..8396577e8 100644 --- a/diskann-benchmark/src/backend/mod.rs +++ b/diskann-benchmark/src/backend/mod.rs @@ -3,14 +3,14 @@ * Licensed under the MIT license. */ +use diskann_benchmark_runner::Registry; + mod disk_index; mod exhaustive; mod filters; mod index; -pub(crate) fn register_benchmarks( - registry: &mut diskann_benchmark_runner::registry::Registry, -) -> anyhow::Result<()> { +pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> { exhaustive::register_benchmarks(registry)?; disk_index::register_benchmarks(registry)?; index::register_benchmarks(registry)?; diff --git a/diskann-benchmark/src/main.rs b/diskann-benchmark/src/main.rs index 2235b0dfa..cc70120cd 100644 --- a/diskann-benchmark/src/main.rs +++ b/diskann-benchmark/src/main.rs @@ -43,7 +43,7 @@ impl Cli { self.check_target(output)?; // Collect benchmarks. - let mut registry = runner::registry::Registry::new(); + let mut registry = runner::Registry::new(); backend::register_benchmarks(&mut registry)?; self.app.run(®istry, output) diff --git a/diskann-benchmark/src/utils/mod.rs b/diskann-benchmark/src/utils/mod.rs index e4bf5cae7..d7417c272 100644 --- a/diskann-benchmark/src/utils/mod.rs +++ b/diskann-benchmark/src/utils/mod.rs @@ -103,8 +103,7 @@ macro_rules! 
stub_impl { use diskann_benchmark_runner::{ dispatcher::{FailureScore, MatchScore}, output::Output, - registry::Registry, - Benchmark, Checkpoint, + Benchmark, Checkpoint, Registry, }; use crate::inputs; From 2f5b0b13ff3647f633ef2297d55b36af2d2a1140 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Tue, 12 May 2026 14:49:10 -0700 Subject: [PATCH 3/9] Cleanups. --- diskann-benchmark-runner/dev/main.rs | 6 +- diskann-benchmark-runner/src/benchmark.rs | 4 +- diskann-benchmark-runner/src/input.rs | 2 +- diskann-benchmark-simd/src/lib.rs | 4 +- diskann-benchmark/README.md | 181 ++++++++---------- .../src/backend/exhaustive/minmax.rs | 2 +- 6 files changed, 88 insertions(+), 111 deletions(-) diff --git a/diskann-benchmark-runner/dev/main.rs b/diskann-benchmark-runner/dev/main.rs index 9f7dae60d..d60b3f533 100644 --- a/diskann-benchmark-runner/dev/main.rs +++ b/diskann-benchmark-runner/dev/main.rs @@ -3,14 +3,14 @@ * Licensed under the MIT license. */ -use diskann_benchmark_runner::{app::App, output, registry}; +use diskann_benchmark_runner::{output, App, Registry}; fn main() -> anyhow::Result<()> { // Parse the command line options. let app = App::parse(); - let mut registry = registry::Benchmarks::new(); - diskann_benchmark_runner::test::register_benchmarks(&mut registry); + let mut registry = Registry::new(); + diskann_benchmark_runner::test::register_benchmarks(&mut registry)?; app.run(®istry, &mut output::default()) } diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 27cb910a9..06c7bb150 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -29,7 +29,7 @@ pub trait Benchmark: 'static { /// /// In the case of ties, the winner is chosen using an unspecified tie-breaking procedure. /// - /// On failure, returns `Err(FailureScore)`. In the [`crate::registry::Benchmarks`] + /// On failure, returns `Err(FailureScore)`. 
In the [`crate::Registry`] /// registry, [`FailureScore`]s will be used to rank the "nearest misses". Implementations /// are encouraged to generate ranked [`FailureScore`]s to assist in user level debugging. fn try_match(&self, input: &Self::Input) -> Result; @@ -68,7 +68,7 @@ pub trait Benchmark: 'static { /// The semantics of pass or failure are left solely to the discretion of the [`Regression`] /// implementation. /// -/// See: [`register_regression`](crate::registry::Benchmarks::register_regression). +/// See: [`register_regression`](crate::Registry::register_regression). pub trait Regression: Benchmark Deserialize<'a>> { /// The tolerance [`Input`] associated with this regression check. type Tolerances: Input + 'static; diff --git a/diskann-benchmark-runner/src/input.rs b/diskann-benchmark-runner/src/input.rs index 07e780b8c..019298542 100644 --- a/diskann-benchmark-runner/src/input.rs +++ b/diskann-benchmark-runner/src/input.rs @@ -40,7 +40,7 @@ pub trait Input { fn example() -> anyhow::Result; } -/// A registered input. See [`crate::registry::Inputs::get`]. +/// A registered input. See [`crate::Registry::input`]. 
#[derive(Clone, Copy)] pub struct Registered<'a>(pub(crate) &'a dyn DynInput); diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 13bfb07b6..b5ab3e503 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -303,8 +303,8 @@ impl std::fmt::Display for CheckResult { //////////////////////////// fn register_benchmarks_impl( - registry: &mut diskann_benchmark_runner::registry::Registry, -) -> Result<(), diskann_benchmark_runner::registry::RegistryError> { + registry: &mut diskann_benchmark_runner::Registry, +) -> Result<(), diskann_benchmark_runner::RegistryError> { // x86-64-v4 #[cfg(target_arch = "x86_64")] { diff --git a/diskann-benchmark/README.md b/diskann-benchmark/README.md index 7b48bca5d..d2ba46e62 100644 --- a/diskann-benchmark/README.md +++ b/diskann-benchmark/README.md @@ -300,27 +300,19 @@ can be run by `cargo test -p benchmark streaming`. ## Adding New Benchmarks -The benchmarking infrastructure uses a loosely-coupled method for dispatching benchmarks -broken into the front end (inputs) and the back end (benchmarks). Inputs can be any `serde` -compatible type. Input registration happens by registering types implementing -`diskann_benchmark_runner::Input` with the `diskann_benchmark_runner::registry::Inputs` -registry. This is done in `inputs::register_inputs`. At run time, the front end will discover -benchmarks in the input JSON file and use the tag string in the "contents" field to select -the correct input deserializer. - -Benchmarks need to be registered with `diskann_benchmark_runner::registry::Benchmarks` by -registering themselves in `benchmark::backend::load()`. To be discoverable by the front-end -input, a `DispatchRule` from the `dispatcher` crate (via -`diskann_benchmark_runner::dispatcher`) needs to be defined matching a back-end type to -`diskann_benchmark_runner::Any`. The dynamic type in the `Any` will come from one of the -registered `diskann_benchmark_runner::Inputs`. 
- -The rule can be as simple as checking a down cast or as complicated such as lifting run-time -information to the type/compile time realm, as is done for the graph index tests for the data -type. - -Once this is complete, the benchmark will be reachable by its input and can live peacefully -with the other benchmarks. +The benchmarking infrastructure works in two phases: first a raw JSON file is parsed into a +collection of registered `diskann_benchmark_runner::Input`s. Then, each input is matched +with a `diskann_benchmark_runner::Benchmark`. A `diskann_benchmark_runner::Registry` contains +the collection of all registered inputs and benchmarks. + +New benchmarks must implement the `diskann_benchmark_runner::Benchmark` trait, which has its +input as an associated type. Registering a benchmark via `Registry::register` will +automatically register the associated input. + +At run time, the front end will discover benchmarks in the input JSON file and use the tag +string in the "contents" field to select the correct input deserializer. Benchmarks will +be matched to inputs using `Benchmark::try_match`, with the best candidate being selected +to be run. ### Example @@ -366,10 +358,9 @@ impl diskann_benchmark_runner::Input for crate::inputs::Input anyhow::Result { @@ -389,26 +380,22 @@ impl diskann_benchmark_runner::Input for crate::inputs::Input Result<(), anyhow::Error> { - // Forward the deserializaiton check to the input files. - self.data.check_deserialization(checkt)?; - self.queries.check_deserialization(checkt)?; + // Forward the deserialization check to the input files. + self.data.check_deserialization(checker)?; + self.queries.check_deserialization(checker)?; Ok(()) } } ``` -#### Front End Registration +#### Benchmark Registration -With the new input type ready, we can register it with the -`diskann_benchmark_runner::registry::Inputs` registry. 
This can be as simple as: -```rust -fn register(registry: &mut diskann_benchmark_runner::registry::Inputs) -> anyhow::Result<()> { - registry.register(crate::inputs::Input::::new()) -} -``` -Note that registration can fail if multiple inputs have the same`tag`. +With the new input type ready, we register a benchmark that uses it with the +`diskann_benchmark_runner::Registry`. Input registration happens automatically as a +side-effect. Registration can fail if a different input type with the same `tag` was already +registered; duplicate registrations of the same tag and type are allowed. -When these steps are completed, our new input will be available using +When a benchmark is registered, the input will be available using ```sh cargo run --release --package diskann-benchmark -- inputs ``` @@ -418,104 +405,94 @@ cargo run --release --package diskann-benchmark -- inputs compute-groundtruth ``` will display an example JSON input for our type. -#### Back End Benchmarks - -So far, we have created a new input type and registered it with the front end, but there are -not any benchmarks that use this type. To implement new benchmarks, we need register them -with the `diskann_benchmark_runner::registry::Benchmarks` returned from -`benchmark::backend::load()`. The simplest thing we can do is something like this: +To implement benchmarks, we register them with the `diskann_benchmark_runner::Registry`. +The simplest thing we can do is something like this: ```rust use diskann_benchmark_runner::{ - dispatcher::{DispatchRule, MatchScore, FailureScore, Ref}, - Any, Checkpoint, Output + dispatcher::{MatchScore, FailureScore}, + Any, Benchmark, Checkpoint, Output, Registry, }; -// Allows the dispatcher to try to match a value with type `CentralDispatch` to the receiver -// type `ComputeGroundTruth`. -impl<'a> DispatchRule<&'a Any> for &'a ComputeGroundTruth { - type Error = anyhow::Error; +// Benchmarks can be stateful. 
+struct RunGroundTruth; - // Will return `Ok` if the dynamic type in `Any` matches - // - // Otherwise, returns a failure. - fn try_match(from: &&'a Any) -> Result { - from.try_match::(from) +impl Benchmark for RunGroundTruth { + // The input that will be registered along with the benchmark. + type Input = ComputeGroundTruth; + + // Real benchmarks should have output that will be saved. For this example, there + // is no meaningful output. + type Output = (); + + // Always match the input. + fn try_match(&self, input: &Self::Input) -> Result { + Ok(MatchScore::new(0)) } - // Will return `Ok` if `from`'s active variant is `ComputeGroundTruth`. - // - // This just forms a reference to the contained value and should only be called if - // `try_match` is successful. - fn convert(from: &'a Any) -> Result { - from.convert::(from) + // Run the benchmark (for this example, nothing happens). + fn run( + &self, + input: &Self::Input, + checkpoint: Checkpoint<'_>, + output: &mut dyn Output, + ) -> anyhow::Result { + Ok(()) } } -fn register(benchmarks: &mut diskann_benchmark_runner::registry::Bencymarks) { - benchmarks.register::>( - "compute-groundtruth", - |input: &ComputeGroundTruth, checkpoint: Checkpoint<'_>, output: &mut dyn Output| { - // Run the benchmark - } - ) +fn register(registry: &mut diskann_benchmark_runner::Registry) -> anyhow::Result<()> { + // Register the benchmark and its associated input. + Ok(registry.register("compute-groundtruth", RunGroundTruth)?) } ``` -What happening here is that the implementation of `DispatchRule` provides a valid conversion -from `&Any` to `&ComputeGroundTruth`, which is only applicable if runtime value in the `Any` -is the `ComputeGroundTruth` struct. If this happens, the benchmarking infrastructure will -call the closure passed to `benchmarks.register()` after calling `DispatchRule::convert()` -on the `Any`. 
This mechanism allows multiple backend benchmarks to exist and pull input from -the deserialized inputs present in the current run. - -There are three more things to note about closures (benchmarks) that get registered with the dispatcher: - -1. The argument `checkpoint: diskann_benchmark_runner::Checkpoint<'_>` allows long-running - benchmarks to periodically save incremental results to file by calling the `.checkpoint` - method. Benchmark results are anything that implements `serde::Serialize`. This function - creates a new snapshot every time it is invoked, so benchmarks to not need to worry about - redundant data. +What is happening here is that the implementation of `Benchmark::try_match` checks if the +benchmark matches the runtime parameters in the associated input. For the case of the example, +this always succeeds. If the `try_match` is successful, then the benchmarking infrastructure +will call `Benchmark::run`. This mechanism allows multiple backend benchmarks to exist and +pull input from the deserialized inputs present in the current run. If multiple benchmarks +match an input, then the benchmark with the lowest `MatchScore` will be selected. -2. The argument `output: &mut dyn diskann_benchmark_runner::Output` is a dynamic type where - all output should be written too. Additionally, it provides a - [`ProgressDrawTarget`](https://docs.rs/indicatif/latest/indicatif/struct.ProgressDrawTarget.html) - for use with [indicatif](https://docs.rs/indicatif/latest/indicatif/index.html) progress bars. - This supports output redirection for integration tests and piping to files. +The argument `checkpoint: diskann_benchmark_runner::Checkpoint<'_>` allows long-running +benchmarks to periodically save incremental results to file by calling the `.checkpoint` +method. This function creates a new snapshot every time it is invoked, so benchmarks do not +need to worry about redundant data. -3. The return type from the closure should be `anyhow::Result`. 
This - contains all data collected from the benchmark and will be collected and saved along with - all other runs. Benchmark implementations do not need to worry about saving their input - as well as this is automatically handled by the benchmarking infrastructure. +The argument `output: &mut dyn diskann_benchmark_runner::Output` is a dynamic type where +all output should be written too. Additionally, it provides a +[`ProgressDrawTarget`](https://docs.rs/indicatif/latest/indicatif/struct.ProgressDrawTarget.html) +for use with [indicatif](https://docs.rs/indicatif/latest/indicatif/index.html) progress bars. +This supports output redirection for integration tests and piping to files. With the benchmark registered, that is all that is needed. -#### Expanding `DispatchRule` +#### Matching with `try_match` -The functionality offered by `DispatchRule` is much more powerful than what was described in -the simple example. In particular, careful implementation will allow your benchmarks to be -more easily discoverable from the command-line and can also assist in debugging by providing -"near misses". +The functionality offered by `Benchmark::try_match` is much more powerful than what was +described in the simple example. In particular, careful implementation will allow your +benchmarks to be more easily discoverable from the command-line and can also assist in +debugging by providing "near misses". **Fine Grained Matching** -The method `DispatchRule::try_match` returns both a successful `MatchScore` and an -unsuccessful `FailureScore`. The dispatcher will only invoke methods where all arguments +The method `Benchmark::try_match` returns both a successful `MatchScore` and an +unsuccessful `FailureScore`. The registry will only invoke methods where all arguments return successful `MatchScores`. Additionally, it will call the method with the "best" -overall score, determined by lexicographic ordering. 
So, you can make some registered -benchmarks "better fits" for inputs returning a better match score. +overall score. So, you can make some registered benchmarks "better fits" for inputs +returning a better match score. -When the dispatcher cannot find any matching method for an input, it begins a process of +When the registry cannot find any matching method for an input, it begins a process of finding the "nearest misses" by inspecting and ranking methods based on their `FailureScore`. Benchmarks can opt-in to this process by returning meaning `FailureScores` when an input is close, but not quite right. **Benchmark Description and Failure Description** -The trait `DispatchRule` has another method: +The trait `Benchmark` has another method: ```rust -fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&&'a Any>); +fn description(f: &mut std::fmt::Formatter<'_>, from: Option<&Self::Input>); ``` -This is used for self-documenting the dispatch rule: If `from` is `None`, then +This is used for self-documenting the matching rule: If `from` is `None`, then implementations should write to the formatter `f` a description of the benchmark and what inputs it can work with. If `from` is `Some`, then implementation should write the reason for a successful or unsuccessful match with the enclosed value. Doing these two steps make diff --git a/diskann-benchmark/src/backend/exhaustive/minmax.rs b/diskann-benchmark/src/backend/exhaustive/minmax.rs index ecd1f1eb7..084da0fb7 100644 --- a/diskann-benchmark/src/backend/exhaustive/minmax.rs +++ b/diskann-benchmark/src/backend/exhaustive/minmax.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -use diskann_benchmark_runner::registry::Registry; +use diskann_benchmark_runner::Registry; const NAME: &str = "minmax-exhaustive-search"; From 1509cc51863e0101adc3ab685b5e82ac985410c3 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Thu, 14 May 2026 11:10:35 -0700 Subject: [PATCH 4/9] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark-runner/src/app.rs | 4 ++-- diskann-benchmark-runner/src/registry.rs | 6 +++--- diskann-benchmark/README.md | 6 +++--- diskann-benchmark/src/backend/index/scalar.rs | 3 ++- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index ac40cd2d5..b46e4fb78 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -39,8 +39,8 @@ //! //! fn main() -> anyhow::Result<()> { //! let mut registry = registry::Registry::new(); -//! // registry.register::("my-bench"); -//! // registry.register_regression::("my-regression"); +//! // registry.register("my-bench", MyBenchmark::default())?; +//! // registry.register_regression("my-regression", MyRegressionBenchmark::default())?; //! //! let app = App::parse(); //! let mut output = diskann_benchmark_runner::output::default(); diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 1cbd6a289..4467c4e09 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -39,7 +39,7 @@ impl RegisteredBenchmark { } } -/// A collection of registered benchmarks. +/// A collection of registered inputs and benchmarks. pub struct Registry { // Inputs keyed by their tag type. inputs: HashMap<&'static str, Box>, @@ -406,13 +406,13 @@ mod tests { }; } - // For the types below, `A` and `B` have distinct tags, but `A2`'s tag conflicts with `A2`. 
+ // For the types below, `A` and `B` have distinct tags, but `A2`'s tag conflicts with `A`. input!(A, "type-a"); input!(B, "type-b"); input!(A2, "type-a"); #[test] - fn test_name_conflicts() { + fn test_tag_conflicts() { let mut registry = Registry::new(); registry.register_input::().unwrap(); registry.register_input::().unwrap(); diff --git a/diskann-benchmark/README.md b/diskann-benchmark/README.md index d2ba46e62..923bb9e27 100644 --- a/diskann-benchmark/README.md +++ b/diskann-benchmark/README.md @@ -310,7 +310,7 @@ input as an associated type. Registering a benchmark via `Registry::register` wi automatically register the associated input. At run time, the front end will discover benchmarks in the input JSON file and use the tag -string in the "contents" field to select the correct input deserializer. Benchmarks will +string in the `type` field to select the correct input deserializer. Benchmarks will be matched to inputs using `Benchmark::try_match`, with the best candidate being selected to be run. @@ -459,7 +459,7 @@ method. This function creates a new snapshot every time it is invoked, so benchm need to worry about redundant data. The argument `output: &mut dyn diskann_benchmark_runner::Output` is a dynamic type where -all output should be written too. Additionally, it provides a +all output should be written to. Additionally, it provides a [`ProgressDrawTarget`](https://docs.rs/indicatif/latest/indicatif/struct.ProgressDrawTarget.html) for use with [indicatif](https://docs.rs/indicatif/latest/indicatif/index.html) progress bars. This supports output redirection for integration tests and piping to files. @@ -483,7 +483,7 @@ returning a better match score. When the registry cannot find any matching method for an input, it begins a process of finding the "nearest misses" by inspecting and ranking methods based on their `FailureScore`. 
-Benchmarks can opt-in to this process by returning meaning `FailureScores` when an input is +Benchmarks can opt-in to this process by returning meaningful `FailureScores` when an input is close, but not quite right. **Benchmark Description and Failure Description** diff --git a/diskann-benchmark/src/backend/index/scalar.rs b/diskann-benchmark/src/backend/index/scalar.rs index ad9fd9738..6dd01f69a 100644 --- a/diskann-benchmark/src/backend/index/scalar.rs +++ b/diskann-benchmark/src/backend/index/scalar.rs @@ -14,7 +14,8 @@ pub(super) fn register_benchmarks(benchmarks: &mut Registry) -> anyhow::Result<( use crate::backend::index::search::plugins::Topk; // NOTE: This benchmark is heavily monomorphized. Each `(NBITS, T)` pair - // generates a full `Registry` impl via the `impl_sq_build!` macro in `mod imp`, + // generates a full `Benchmark` impl/build path for + // `ScalarQuantized` via the `impl_sq_build!` macro in `mod imp`, // which materially impacts compile time. We intentionally keep the registered // set minimal (`f32` at 1, 4, and 8 bits) to cover the common cases used by // `example/scalar.json`. From dc35131b548b7e9f1fc83f1cba5264d5bb163f57 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 11:32:47 -0700 Subject: [PATCH 5/9] Simplify Any. 
--- diskann-benchmark-runner/src/any.rs | 196 ------------------ diskann-benchmark-runner/src/app.rs | 6 +- diskann-benchmark-runner/src/benchmark.rs | 2 +- diskann-benchmark-runner/src/checker.rs | 35 +--- diskann-benchmark-runner/src/files.rs | 36 ++-- diskann-benchmark-runner/src/input.rs | 83 +++++++- .../src/internal/regression.rs | 5 +- diskann-benchmark-runner/src/jobs.rs | 12 +- diskann-benchmark-runner/src/lib.rs | 4 +- diskann-benchmark-runner/src/registry.rs | 35 ++-- diskann-benchmark-runner/src/result.rs | 6 +- diskann-benchmark-runner/src/test/dim.rs | 41 ++-- diskann-benchmark-runner/src/test/typed.rs | 98 ++++----- .../benchmark/test-mismatch-0/stdout.txt | 3 +- .../benchmark/test-mismatch-1/stdout.txt | 3 +- .../benchmark/test-overload-0/output.json | 6 +- .../benchmark/test-success-0/output.json | 6 +- .../regression/check-run-error-0/output.json | 3 +- .../regression/check-run-error-2/output.json | 3 +- .../regression/check-run-fail-0/output.json | 3 +- .../regression/check-run-pass-0/output.json | 6 +- diskann-benchmark-simd/src/lib.rs | 52 ++--- .../src/backend/disk_index/benchmarks.rs | 27 ++- diskann-benchmark/src/inputs/disk.rs | 47 ++--- diskann-benchmark/src/inputs/exhaustive.rs | 49 ++--- diskann-benchmark/src/inputs/filters.rs | 15 +- diskann-benchmark/src/inputs/graph_index.rs | 155 ++++++-------- diskann-benchmark/src/inputs/mod.rs | 26 +-- 28 files changed, 349 insertions(+), 614 deletions(-) delete mode 100644 diskann-benchmark-runner/src/any.rs diff --git a/diskann-benchmark-runner/src/any.rs b/diskann-benchmark-runner/src/any.rs deleted file mode 100644 index d25a58a9e..000000000 --- a/diskann-benchmark-runner/src/any.rs +++ /dev/null @@ -1,196 +0,0 @@ -/* - * Copyright (c) Microsoft Corporation. - * Licensed under the MIT license. - */ - -/// An refinement of [`std::any::Any`] with an associated name (tag) and serialization. 
-/// -/// This type represents deserialized inputs returned from [`crate::Input::try_deserialize`] -/// and is passed to beckend benchmarks for matching and execution. -#[derive(Debug)] -pub struct Any { - any: Box, - tag: &'static str, -} - -impl Any { - /// Construct a new [`Any`] around `any` and associate it with the name `tag`. - /// - /// The tag is included as merely a debugging and readability aid and usually should - /// belong to a [`crate::Input::tag`] that generated `any`. - pub fn new(any: T, tag: &'static str) -> Self - where - T: serde::Serialize + std::fmt::Debug + 'static, - { - Self { - any: Box::new(any), - tag, - } - } - - /// A lower level API for constructing an [`Any`] that decouples the serialized - /// representation from the inmemory representation. - /// - /// When serialized, the **exact** representation of `repr` will be used. - /// - /// This is useful in some contexts where as part of input resolution, a fully resolved - /// input struct contains elements that are not serializable. - /// - /// Like [`Any::new`], the tag is included for debugging and readability. - pub fn raw(any: T, repr: serde_json::Value, tag: &'static str) -> Self - where - T: std::fmt::Debug + 'static, - { - Self { - any: Box::new(Raw::new(any, repr)), - tag, - } - } - - /// Return the benchmark tag associated with this benchmarks. - pub fn tag(&self) -> &'static str { - self.tag - } - - /// Return the Rust [`std::any::TypeId`] for the contained object. - pub fn type_id(&self) -> std::any::TypeId { - self.any.as_any().type_id() - } - - /// Return `true` if the runtime value is `T`. Otherwise, return false. 
- /// - /// ```rust - /// use diskann_benchmark_runner::any::Any; - /// - /// let value = Any::new(42usize, "usize"); - /// assert!(value.is::()); - /// assert!(!value.is::()); - /// ``` - #[must_use = "this function has no side effects"] - pub fn is(&self) -> bool - where - T: std::any::Any, - { - self.any.as_any().is::() - } - - /// Return a reference to the contained object if it's runtime type is `T`. - /// - /// Otherwise return `None`. - /// - /// ```rust - /// use diskann_benchmark_runner::any::Any; - /// - /// let value = Any::new(42usize, "usize"); - /// assert_eq!(*value.downcast_ref::().unwrap(), 42); - /// assert!(value.downcast_ref::().is_none()); - /// ``` - pub fn downcast_ref(&self) -> Option<&T> - where - T: std::any::Any, - { - self.any.as_any().downcast_ref::() - } - - /// Serialize the contained object to a [`serde_json::Value`]. - pub fn serialize(&self) -> Result { - self.any.dump() - } -} - -trait SerializableAny: std::fmt::Debug { - fn as_any(&self) -> &dyn std::any::Any; - fn dump(&self) -> Result; -} - -impl SerializableAny for T -where - T: std::any::Any + serde::Serialize + std::fmt::Debug, -{ - fn as_any(&self) -> &dyn std::any::Any { - self - } - - fn dump(&self) -> Result { - serde_json::to_value(self) - } -} - -// A backend type that allows users to decouple the serialized representation from the -// actual type. 
-#[derive(Debug)] -struct Raw { - value: T, - repr: serde_json::Value, -} - -impl Raw { - fn new(value: T, repr: serde_json::Value) -> Self { - Self { value, repr } - } -} - -impl SerializableAny for Raw -where - T: std::any::Any + std::fmt::Debug, -{ - fn as_any(&self) -> &dyn std::any::Any { - &self.value - } - - fn dump(&self) -> Result { - Ok(self.repr.clone()) - } -} - -/////////// -// Tests // -/////////// - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_new() { - let x = Any::new(42usize, "my-tag"); - assert_eq!(x.tag(), "my-tag"); - assert_eq!(x.type_id(), std::any::TypeId::of::()); - assert!(x.is::()); - assert!(!x.is::()); - assert_eq!(*x.downcast_ref::().unwrap(), 42); - assert!(x.downcast_ref::().is_none()); - - assert!(!x.is::>()); - assert!(!x.is::>()); - assert!(x.downcast_ref::>().is_none()); - assert!(x.downcast_ref::>().is_none()); - - assert_eq!( - x.serialize().unwrap(), - serde_json::Value::Number(serde_json::value::Number::from(42usize)) - ); - } - - #[test] - fn test_raw() { - let repr = serde_json::json!(1.5); - let x = Any::raw(42usize, repr, "my-tag"); - assert_eq!(x.tag(), "my-tag"); - assert_eq!(x.type_id(), std::any::TypeId::of::()); - assert!(x.is::()); - assert!(!x.is::()); - assert_eq!(*x.downcast_ref::().unwrap(), 42); - assert!(x.downcast_ref::().is_none()); - - assert!(!x.is::>()); - assert!(!x.is::>()); - assert!(x.downcast_ref::>().is_none()); - assert!(x.downcast_ref::>().is_none()); - - assert_eq!( - x.serialize().unwrap(), - serde_json::Value::Number(serde_json::value::Number::from_f64(1.5).unwrap()) - ); - } -} diff --git a/diskann-benchmark-runner/src/app.rs b/diskann-benchmark-runner/src/app.rs index b46e4fb78..42f146ca5 100644 --- a/diskann-benchmark-runner/src/app.rs +++ b/diskann-benchmark-runner/src/app.rs @@ -293,12 +293,12 @@ impl App { let serialized = jobs .iter() .map(|job| { - serde_json::to_value(jobs::Unprocessed::new( + Ok(serde_json::to_value(jobs::Unprocessed::new( job.tag().into(), 
job.serialize()?, - )) + ))?) }) - .collect::, serde_json::Error>>()?; + .collect::>>()?; for (i, job) in jobs.iter().enumerate() { let prefix: &str = if i != 0 { "\n\n" } else { "" }; writeln!( diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 39f48fd54..1130fefaa 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; -use crate::{Any, Checkpoint, Input, Output}; +use crate::{input::Any, Checkpoint, Input, Output}; /// A registered benchmark. /// diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index 4b3dda556..da67fb113 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -8,8 +8,6 @@ use std::{ path::{Path, PathBuf}, }; -use crate::Any; - /// Shared context for resolving input and output files paths post deserialization. #[derive(Debug)] pub struct Checker { @@ -29,12 +27,6 @@ pub struct Checker { /// /// This ensures that each job uses a distinct output directory to avoid conflicts. current_outputs: HashSet, - - /// This crate-private variable is used to store the current input deserialization - /// tag and is referenced when creating new `Any` objects. - /// - /// Ensure that the correct tag is present before invoking [`Input::try_deserialize`]. - tag: Option<&'static str>, } impl Checker { @@ -44,23 +36,9 @@ impl Checker { search_directories, output_directory, current_outputs: HashSet::new(), - tag: None, } } - /// Invoke [`CheckDeserialization`] on `value` and if successful, package it in [`Any`]. 
- pub fn any(&mut self, mut value: T) -> anyhow::Result - where - T: serde::Serialize + CheckDeserialization + std::fmt::Debug + 'static, - { - value.check_deserialization(self)?; - #[expect( - clippy::expect_used, - reason = "crate infrastructure ensures an untagged Checker is not leaked" - )] - Ok(Any::new(value, self.tag.expect("tag must be set"))) - } - /// Return the ordered list of search directories registered with the [`Checker`]. pub fn search_directories(&self) -> &[PathBuf] { &self.search_directories @@ -168,16 +146,9 @@ impl Checker { ))) } - pub(crate) fn set_tag(&mut self, tag: &'static str) { - let _ = self.tag.insert(tag); - } -} - -/// Perform post-process resolution of input and output files paths. -pub trait CheckDeserialization { - /// Perform any necessary resolution of file paths, returning an error if a problem is - /// discovered. - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error>; + // pub(crate) fn set_tag(&mut self, tag: &'static str) { + // let _ = self.tag.insert(tag); + // } } /////////// diff --git a/diskann-benchmark-runner/src/files.rs b/diskann-benchmark-runner/src/files.rs index 355b47010..1672f6f62 100644 --- a/diskann-benchmark-runner/src/files.rs +++ b/diskann-benchmark-runner/src/files.rs @@ -7,7 +7,7 @@ use std::path::{Path, PathBuf}; use serde::{Deserialize, Serialize}; -use super::checker::{CheckDeserialization, Checker}; +use super::Checker; /// A file that is used as an input to for a benchmark. 
/// @@ -37,17 +37,8 @@ impl InputFile { path: PathBuf::from(path), } } -} - -impl std::ops::Deref for InputFile { - type Target = Path; - fn deref(&self) -> &Self::Target { - &self.path - } -} -impl CheckDeserialization for InputFile { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + pub fn resolve(&mut self, checker: &mut Checker) -> anyhow::Result<()> { let checked_path = checker.check_path(self); match checked_path { Ok(p) => { @@ -59,6 +50,13 @@ impl CheckDeserialization for InputFile { } } +impl std::ops::Deref for InputFile { + type Target = Path; + fn deref(&self) -> &Self::Target { + &self.path + } +} + /////////// // Tests // /////////// @@ -86,7 +84,7 @@ mod tests { } #[test] - fn test_check_deserialization() { + fn test_resolve() { // We create a directory that looks like this: // // dir/ @@ -113,13 +111,13 @@ mod tests { let absolute = path.join("file_a.txt"); let mut file = InputFile::new(absolute.clone()); let mut checker = Checker::new(Vec::new(), None); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, absolute); let absolute = path.join("dir0/file_b.txt"); let mut file = InputFile::new(absolute.clone()); let mut checker = Checker::new(Vec::new(), None); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, absolute); } @@ -128,7 +126,7 @@ mod tests { let absolute = path.join("dir0/file_c.txt"); let mut file = InputFile::new(absolute.clone()); let mut checker = Checker::new(Vec::new(), None); - let err = file.check_deserialization(&mut checker).unwrap_err(); + let err = file.resolve(&mut checker).unwrap_err(); let message = err.to_string(); assert!(message.contains("input file with absolute path")); assert!(message.contains("either does not exist or is not a file")); @@ -143,23 +141,23 @@ mod tests { // Directories are searched in order. 
let mut file = InputFile::new("file_c.txt"); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, path.join("dir1/dir0/file_c.txt")); let mut file = InputFile::new("file_b.txt"); - file.check_deserialization(&mut checker).unwrap(); + file.resolve(&mut checker).unwrap(); assert_eq!(file.path, path.join("dir0/file_b.txt")); // Directory search can fail. let mut file = InputFile::new("file_a.txt"); - let err = file.check_deserialization(&mut checker).unwrap_err(); + let err = file.resolve(&mut checker).unwrap_err(); let message = err.to_string(); assert!(message.contains("could not find input file")); assert!(message.contains("in the search directories")); // If we give an absolute path, no directory search is performed. let mut file = InputFile::new(path.join("file_c.txt")); - let err = file.check_deserialization(&mut checker).unwrap_err(); + let err = file.resolve(&mut checker).unwrap_err(); let message = err.to_string(); assert!(message.starts_with("input file with absolute path")); } diff --git a/diskann-benchmark-runner/src/input.rs b/diskann-benchmark-runner/src/input.rs index 019298542..203cd6f1e 100644 --- a/diskann-benchmark-runner/src/input.rs +++ b/diskann-benchmark-runner/src/input.rs @@ -3,9 +3,11 @@ * Licensed under the MIT license. */ -use crate::{Any, Checker}; +use crate::Checker; + +pub trait Input: Sized + std::fmt::Debug + 'static { + type Raw: serde::de::DeserializeOwned + serde::Serialize; -pub trait Input { /// Return the discriminant associated with this type. /// /// This is used to map inputs types to their respective parsers. @@ -26,10 +28,10 @@ pub trait Input { /// [`CheckDeserialization`](crate::CheckDeserialization) and use this API to ensure /// shared resources (like input files or output files) are correctly resolved and /// properly shared among all jobs in a benchmark run. 
- fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result; + fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result; + + /// Serialize `self` to a [`serde_json::Value`]. + fn serialize(&self) -> anyhow::Result; /// Print an example JSON representation of objects this input is expected to parse. /// @@ -37,7 +39,7 @@ pub trait Input { /// [`serde_json::Value`] back to [`Self::try_deserialize`] correctly deserializes, /// though it need not necessarily pass /// [`CheckDeserialization`](crate::CheckDeserialization). - fn example() -> anyhow::Result; + fn example() -> Self::Raw; } /// A registered input. See [`crate::Registry::input`]. @@ -55,7 +57,7 @@ impl Registered<'_> { /// Try to deserialize raw JSON into the dynamic type of the input. /// /// See: [`Input::try_deserialize`]. - pub fn try_deserialize( + pub(crate) fn try_deserialize( &self, serialized: &serde_json::Value, checker: &mut Checker, @@ -83,6 +85,64 @@ impl std::fmt::Debug for Registered<'_> { // Internal // ////////////// +/// Runtime representation of a deserialized [`Input`]. 
+#[derive(Debug)] +pub(crate) struct Any { + any: Box, +} + +impl Any { + pub(crate) fn new(input: T) -> Self + where + T: Input + std::fmt::Debug + 'static, + { + Self { + any: Box::new(input), + } + } + + #[must_use = "this function has no side effects"] + pub(crate) fn tag(&self) -> &'static str { + self.any.tag() + } + + #[must_use = "this function has no side effects"] + pub(crate) fn downcast_ref(&self) -> Option<&T> + where + T: std::any::Any, + { + self.any.as_any().downcast_ref::() + } + + #[must_use = "this function has no side effects"] + pub(crate) fn serialize(&self) -> anyhow::Result { + self.any.serialize() + } +} + +trait RuntimeAny: std::fmt::Debug { + fn tag(&self) -> &'static str; + fn as_any(&self) -> &dyn std::any::Any; + fn serialize(&self) -> anyhow::Result; +} + +impl RuntimeAny for T +where + T: Input + std::fmt::Debug + 'static, +{ + fn tag(&self) -> &'static str { + ::tag() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn serialize(&self) -> anyhow::Result { + ::serialize(self) + } +} + #[derive(Debug)] pub(crate) struct Wrapper(std::marker::PhantomData); @@ -118,7 +178,7 @@ pub(crate) trait DynInput { impl DynInput for Wrapper where - T: Input + 'static, + T: Input + std::fmt::Debug + 'static, { fn tag(&self) -> &'static str { T::tag() @@ -128,10 +188,11 @@ where serialized: &serde_json::Value, checker: &mut Checker, ) -> anyhow::Result { - T::try_deserialize(serialized, checker) + let raw = >::deserialize(serialized)?; + Ok(Any::new(T::from_raw(raw, checker)?)) } fn example(&self) -> anyhow::Result { - T::example() + Ok(serde_json::to_value(T::example())?) 
} fn as_any(&self) -> &dyn std::any::Any { self diff --git a/diskann-benchmark-runner/src/internal/regression.rs b/diskann-benchmark-runner/src/internal/regression.rs index c1d1838ee..eb7289cc7 100644 --- a/diskann-benchmark-runner/src/internal/regression.rs +++ b/diskann-benchmark-runner/src/internal/regression.rs @@ -100,8 +100,9 @@ use serde_json::Value; use crate::{ benchmark::{internal::CheckedPassFail, PassFail}, + input::Any, internal::load_from_disk, - jobs, registry, result, Any, Checker, + jobs, registry, result, Checker, }; //////////// @@ -348,7 +349,7 @@ impl Raw { .with_context(context); } - checker.set_tag(entry.tolerance.tag()); + // checker.set_tag(entry.tolerance.tag()); let tolerance = entry .tolerance .try_deserialize(&unprocessed.tolerance.content, &mut checker) diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index e7ca99a2a..9c43acbef 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -8,22 +8,22 @@ use std::path::{Path, PathBuf}; use anyhow::Context; use serde::{Deserialize, Serialize}; -use crate::{checker::Checker, input, Any, Registry}; +use crate::{checker::Checker, input, Registry}; #[derive(Debug)] pub(crate) struct Jobs { /// The benchmark jobs to execute. - jobs: Vec, + jobs: Vec, } impl Jobs { /// Return the jobs associated with this benchmark run. - pub(crate) fn jobs(&self) -> &[Any] { + pub(crate) fn jobs(&self) -> &[input::Any] { &self.jobs } /// Consume `self`, returning the contained list of jobs. 
- pub(crate) fn into_inner(self) -> Vec { + pub(crate) fn into_inner(self) -> Vec { self.jobs } @@ -51,7 +51,7 @@ impl Jobs { ); let num_jobs = partial.jobs.len(); - let jobs: anyhow::Result> = partial + let jobs: anyhow::Result> = partial .jobs .iter() .enumerate() @@ -71,7 +71,7 @@ impl Jobs { }) .with_context(context)?; - checker.set_tag(input.tag()); + // checker.set_tag(input.tag()); input .try_deserialize(&unprocessed.content, &mut checker) .with_context(context) diff --git a/diskann-benchmark-runner/src/lib.rs b/diskann-benchmark-runner/src/lib.rs index e0c3b2791..724a827f6 100644 --- a/diskann-benchmark-runner/src/lib.rs +++ b/diskann-benchmark-runner/src/lib.rs @@ -11,7 +11,6 @@ mod internal; mod jobs; mod result; -pub mod any; pub mod app; pub mod files; pub mod input; @@ -19,10 +18,9 @@ pub mod output; pub mod registry; pub mod utils; -pub use any::Any; pub use app::App; pub use benchmark::Benchmark; -pub use checker::{CheckDeserialization, Checker}; +pub use checker::Checker; pub use input::Input; pub use output::Output; pub use registry::{Registry, RegistryError}; diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index f19bb93d3..91a82ed34 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -9,7 +9,7 @@ use thiserror::Error; use crate::{ benchmark::{self, Benchmark, FailureScore, MatchScore, Regression}, - input, Any, Checkpoint, Input, Output, + input, Checkpoint, Input, Output, }; /// A registered benchmark entry: a name paired with a type-erased benchmark. @@ -105,7 +105,7 @@ impl Registry { } /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`. - pub fn has_match(&self, job: &Any) -> bool { + pub(crate) fn has_match(&self, job: &input::Any) -> bool { self.find_best_match(job).is_some() } @@ -114,9 +114,9 @@ impl Registry { /// Returns the results of the benchmark if successful. 
/// /// Errors if a suitable method could not be found or if the invoked benchmark failed. - pub fn call( + pub(crate) fn call( &self, - job: &Any, + job: &input::Any, checkpoint: Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { @@ -132,7 +132,7 @@ impl Registry { /// reasons. /// /// Returns `Ok(())` if a match was found. - pub fn debug(&self, job: &Any, max_methods: usize) -> Result<(), Vec> { + pub(crate) fn debug(&self, job: &input::Any, max_methods: usize) -> Result<(), Vec> { if self.has_match(job) { return Ok(()); } @@ -166,7 +166,7 @@ impl Registry { } /// Find the best matching benchmark for `job` by score. - fn find_best_match(&self, job: &Any) -> Option<&RegisteredBenchmark> { + fn find_best_match(&self, job: &input::Any) -> Option<&RegisteredBenchmark> { self.benchmarks .iter() .filter_map(|entry| { @@ -334,14 +334,14 @@ impl RegressionBenchmark<'_> { self.regression.input_tag() } - pub(crate) fn try_match(&self, input: &Any) -> Result { + pub(crate) fn try_match(&self, input: &input::Any) -> Result { self.benchmark.benchmark().try_match(input) } pub(crate) fn check( &self, - tolerance: &Any, - input: &Any, + tolerance: &input::Any, + input: &input::Any, before: &serde_json::Value, after: &serde_json::Value, ) -> anyhow::Result { @@ -360,7 +360,10 @@ pub(crate) struct RegisteredTolerance<'a> { } /// Helper to capture a `Benchmark::description` call into a `String` via `Display`. -struct Capture<'a>(&'a dyn benchmark::internal::Benchmark, Option<&'a Any>); +struct Capture<'a>( + &'a dyn benchmark::internal::Benchmark, + Option<&'a input::Any>, +); impl std::fmt::Display for Capture<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { @@ -386,19 +389,21 @@ mod tests { macro_rules! 
input { ($T:ident, $tag:literal) => { + #[derive(Debug)] struct $T; impl Input for $T { + type Raw = (); fn tag() -> &'static str { $tag } - fn try_deserialize( - _serialized: &serde_json::Value, - _checker: &mut Checker, - ) -> anyhow::Result { + fn from_raw(_raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result<$T> { unimplemented!("this struct is for test only"); } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + unimplemented!("this struct is for test only"); + } + fn example() -> Self::Raw { unimplemented!("this struct is for test only"); } } diff --git a/diskann-benchmark-runner/src/result.rs b/diskann-benchmark-runner/src/result.rs index cd8e34bb8..f80b44581 100644 --- a/diskann-benchmark-runner/src/result.rs +++ b/diskann-benchmark-runner/src/result.rs @@ -267,9 +267,9 @@ mod tests { let savepath = path.join("output.json"); let inputs = [ - TypeInput::new(DataType::Float32, 1, false), - TypeInput::new(DataType::Float16, 2, false), - TypeInput::new(DataType::Float64, 3, false), + TypeInput::new(DataType::Float32, 1), + TypeInput::new(DataType::Float16, 2), + TypeInput::new(DataType::Float64, 3), ]; let serialized: Vec<_> = inputs diff --git a/diskann-benchmark-runner/src/test/dim.rs b/diskann-benchmark-runner/src/test/dim.rs index edc033a38..d70456883 100644 --- a/diskann-benchmark-runner/src/test/dim.rs +++ b/diskann-benchmark-runner/src/test/dim.rs @@ -9,7 +9,7 @@ use serde::{Deserialize, Serialize}; use crate::{ benchmark::{FailureScore, MatchScore, PassFail, Regression}, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, + Benchmark, Checker, Checkpoint, Input, Output, }; /////////// @@ -32,25 +32,22 @@ impl DimInput { } impl Input for DimInput { + type Raw = Self; + fn tag() -> &'static str { "test-input-dim" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(DimInput::deserialize(serialized)?) 
+ fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { - Ok(serde_json::to_value(DimInput::new(Some(128)))?) + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } -} -impl CheckDeserialization for DimInput { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { - Ok(()) + fn example() -> Self::Raw { + DimInput::new(Some(128)) } } @@ -65,23 +62,25 @@ pub(super) struct Tolerance { } impl Input for Tolerance { + type Raw = Self; + fn tag() -> &'static str { "test-input-dim-tolerance" } - fn try_deserialize( - serialized: &serde_json::Value, - _checker: &mut Checker, - ) -> anyhow::Result { - Ok(Any::new(Self::deserialize(serialized)?, Self::tag())) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { - let this = Self { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self::Raw { + Self { succeed: true, error_in_check: false, - }; - Ok(serde_json::to_value(this)?) + } } } diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index b34d4301d..f3c0c3885 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -10,32 +10,30 @@ use serde::{Deserialize, Serialize}; use crate::{ benchmark::{FailureScore, MatchScore, PassFail, Regression}, utils::datatype::{AsDataType, DataType}, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Output, + Benchmark, Checker, Checkpoint, Input, Output, }; /////////// // Input // /////////// -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub(crate) struct TypeInput { pub(super) data_type: DataType, pub(super) dim: usize, - // Should we return an error when `check_deserialization` is called? 
- pub(super) error_when_checked: bool, - // A flag to verify that [`CheckDeserialization`] has run. - #[serde(skip)] - pub(crate) checked: bool, +} + +#[derive(Serialize, Deserialize)] +pub(crate) struct TypeInputRaw { + data_type: DataType, + dim: usize, + // Should we return an error when deserializing? + error_when_checked: bool, } impl TypeInput { - pub(crate) fn new(data_type: DataType, dim: usize, error_when_checked: bool) -> Self { - Self { - data_type, - dim, - error_when_checked, - checked: false, - } + pub(crate) fn new(data_type: DataType, dim: usize) -> Self { + Self { data_type, dim } } fn run(&self) -> &'static str { @@ -44,33 +42,29 @@ impl TypeInput { } impl Input for TypeInput { + type Raw = TypeInputRaw; + fn tag() -> &'static str { "test-input-types" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(TypeInput::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + if raw.error_when_checked { + Err(anyhow::anyhow!("test input erroring when checked")) + } else { + Ok(Self::new(raw.data_type, raw.dim)) + } } - fn example() -> anyhow::Result { - Ok(serde_json::to_value(TypeInput::new( - DataType::Float32, - 128, - false, - ))?) + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } -} -impl CheckDeserialization for TypeInput { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { - if self.error_when_checked { - Err(anyhow::anyhow!("test input erroring when checked")) - } else { - self.checked = true; - Ok(()) + fn example() -> Self::Raw { + TypeInputRaw { + data_type: DataType::Float32, + dim: 128, + error_when_checked: false, } } } @@ -81,42 +75,32 @@ impl CheckDeserialization for TypeInput { #[derive(Debug, Serialize, Deserialize)] pub(super) struct Tolerance { - // Should we return an error when `check_deserialization` is called? 
+ // Should we return an error when `from_raw` is called? pub(super) error_when_checked: bool, - - // A flag to verify that [`CheckDeserialization`] has run. - #[serde(skip)] - pub(crate) checked: bool, } impl Input for Tolerance { + type Raw = Self; + fn tag() -> &'static str { "test-input-types-tolerance" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + if raw.error_when_checked { + Err(anyhow::anyhow!("test input erroring when checked")) + } else { + Ok(raw) + } } - fn example() -> anyhow::Result { - let this = Self { - error_when_checked: false, - checked: false, - }; - Ok(serde_json::to_value(this)?) + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } -} -impl CheckDeserialization for Tolerance { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { - if self.error_when_checked { - Err(anyhow::anyhow!("test input erroring when checked")) - } else { - self.checked = true; - Ok(()) + fn example() -> Self::Raw { + Self { + error_when_checked: false, } } } diff --git a/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt index ba72e9bbf..7e26fb341 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt +++ b/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt @@ -2,8 +2,7 @@ Could not find a match for the following input: { "data_type": "float16", - "dim": 128, - "error_when_checked": false + "dim": 128 } Closest matches: diff --git a/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt index 34be87554..85da0f389 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt +++ 
b/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt @@ -2,8 +2,7 @@ Could not find a match for the following input: { "data_type": "float16", - "dim": 1000, - "error_when_checked": false + "dim": 1000 } Closest matches: diff --git a/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json b/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json index 8fdbaa7e4..d53fdfdd2 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json +++ b/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json @@ -3,8 +3,7 @@ "input": { "content": { "data_type": "float32", - "dim": 1000, - "error_when_checked": false + "dim": 1000 }, "type": "test-input-types" }, @@ -14,8 +13,7 @@ "input": { "content": { "data_type": "float32", - "dim": 128, - "error_when_checked": false + "dim": 128 }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json b/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json index 5b15f5ac2..398b4f358 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json +++ b/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json @@ -21,8 +21,7 @@ "input": { "content": { "data_type": "float32", - "dim": 128, - "error_when_checked": false + "dim": 128 }, "type": "test-input-types" }, @@ -32,8 +31,7 @@ "input": { "content": { "data_type": "int8", - "dim": 128, - "error_when_checked": false + "dim": 128 }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json b/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json index 9bad43329..f79fc4729 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json @@ -12,8 +12,7 @@ "input": { "content": { "data_type": "int8", - "dim": 128, - "error_when_checked": false + "dim": 128 
}, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json b/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json index 869ae3cc8..4f93315f1 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json @@ -3,8 +3,7 @@ "input": { "content": { "data_type": "float32", - "dim": 128, - "error_when_checked": false + "dim": 128 }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json b/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json index 5112717a9..77c4fa441 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json @@ -12,8 +12,7 @@ "input": { "content": { "data_type": "int8", - "dim": 128, - "error_when_checked": false + "dim": 128 }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json b/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json index 5386e9dee..540e94d1b 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json @@ -12,8 +12,7 @@ "input": { "content": { "data_type": "float32", - "dim": 1000, - "error_when_checked": false + "dim": 1000 }, "type": "test-input-types" }, @@ -23,8 +22,7 @@ "input": { "content": { "data_type": "int8", - "dim": 128, - "error_when_checked": false + "dim": 128 }, "type": "test-input-types" }, diff --git a/diskann-benchmark-simd/src/lib.rs b/diskann-benchmark-simd/src/lib.rs index 9fe99e9e9..cf12ae373 100644 --- a/diskann-benchmark-simd/src/lib.rs +++ b/diskann-benchmark-simd/src/lib.rs @@ -26,7 +26,7 @@ use diskann_benchmark_runner::{ num::{relative_change, NonNegativeFinite}, percentiles, 
MicroSeconds, }, - Any, Benchmark, CheckDeserialization, Checker, Input, Registry, + Benchmark, Checker, Input, Registry, }; //////////////// @@ -118,12 +118,6 @@ pub struct SimdOp { runs: Vec, } -impl CheckDeserialization for SimdOp { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - Ok(()) - } -} - macro_rules! write_field { ($f:ident, $field:tt, $($expr:tt)*) => { writeln!($f, "{:>18}: {}", $field, $($expr)*) @@ -149,18 +143,21 @@ impl std::fmt::Display for SimdOp { } impl Input for SimdOp { + type Raw = Self; + fn tag() -> &'static str { "simd-op" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) + } + + fn example() -> Self::Raw { const DIM: [NonZeroUsize; 2] = [ NonZeroUsize::new(128).unwrap(), NonZeroUsize::new(150).unwrap(), @@ -191,12 +188,12 @@ impl Input for SimdOp { }, ]; - Ok(serde_json::to_value(&Self { + Self { query_type: DataType::Float32, data_type: DataType::Float32, arch: Arch::X86_64_V3, runs, - })?) + } } } @@ -213,33 +210,30 @@ struct SimdTolerance { min_time_regression: NonNegativeFinite, } -impl CheckDeserialization for SimdTolerance { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - Ok(()) - } -} - impl Input for SimdTolerance { + type Raw = Self; + fn tag() -> &'static str { "simd-tolerance" } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) 
+ } + + fn example() -> Self { const EXAMPLE: NonNegativeFinite = match NonNegativeFinite::new(0.10) { Ok(v) => v, Err(_) => panic!("use a non-negative finite please"), }; - Ok(serde_json::to_value(SimdTolerance { + SimdTolerance { min_time_regression: EXAMPLE, - })?) + } } } diff --git a/diskann-benchmark/src/backend/disk_index/benchmarks.rs b/diskann-benchmark/src/backend/disk_index/benchmarks.rs index f41e39346..c81022b97 100644 --- a/diskann-benchmark/src/backend/disk_index/benchmarks.rs +++ b/diskann-benchmark/src/backend/disk_index/benchmarks.rs @@ -15,7 +15,7 @@ use diskann_benchmark_runner::{ fmt::Table, num::{relative_change, NonNegativeFinite}, }, - Any, Benchmark, CheckDeserialization, Checker, Checkpoint, Input, Registry, + Benchmark, Checker, Checkpoint, Input, Registry, }; use diskann_providers::storage::FileStorageProvider; use half::f16; @@ -165,25 +165,22 @@ impl DiskIndexTolerance { } } -impl CheckDeserialization for DiskIndexTolerance { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { - Ok(()) - } -} - impl Input for DiskIndexTolerance { + type Raw = Self; + fn tag() -> &'static str { Self::tag() } - fn try_deserialize( - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - checker.any(Self::deserialize(serialized)?) + fn from_raw(raw: Self::Raw, _checker: &mut Checker) -> anyhow::Result { + Ok(raw) } - fn example() -> anyhow::Result { + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) 
+ } + + fn example() -> Self { const DEFAULT: NonNegativeFinite = match NonNegativeFinite::new(0.10) { Ok(v) => v, Err(_) => panic!("use a non-negative finite value"), @@ -193,7 +190,7 @@ impl Input for DiskIndexTolerance { Err(_) => panic!("use a non-negative finite value"), }; - Ok(serde_json::to_value(DiskIndexTolerance { + DiskIndexTolerance { build_time_regression: DEFAULT, qps_regression: DEFAULT, recall_regression: RECALL, @@ -201,7 +198,7 @@ impl Input for DiskIndexTolerance { mean_comps_regression: DEFAULT, mean_latency_regression: DEFAULT, p95_latency_regression: DEFAULT, - })?) + } } } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 00f6067d4..473d7982b 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -6,9 +6,7 @@ use std::{fmt, num::NonZeroUsize, path::Path}; use anyhow::Context; -use diskann_benchmark_runner::{ - files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, -}; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; #[cfg(feature = "disk-index")] use diskann_disk::QuantizationType; use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file}; @@ -88,29 +86,19 @@ impl DiskIndexOperation { pub(crate) const fn tag() -> &'static str { "disk-index" } -} -/////////////////////////// -// Check Deserialization // -/////////////////////////// - -impl CheckDeserialization for DiskIndexOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // validate the source + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match &mut self.source { - DiskIndexSource::Load(load) => load.check_deserialization(checker)?, - DiskIndexSource::Build(build) => build.check_deserialization(checker)?, + DiskIndexSource::Load(load) => load.validate(checker)?, + DiskIndexSource::Build(build) => 
build.validate(checker)?, } - - // validate the search phase - self.search_phase.check_deserialization(checker)?; - + self.search_phase.validate(checker)?; Ok(()) } } -impl CheckDeserialization for DiskIndexLoad { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { +impl DiskIndexLoad { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { let files = [ (get_pq_pivot_file(&self.load_path), "pq pivot file"), ( @@ -131,12 +119,9 @@ impl CheckDeserialization for DiskIndexLoad { } } -impl CheckDeserialization for DiskIndexBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // file input - self.data - .check_deserialization(checker) - .context("invalid data file")?; +impl DiskIndexBuild { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.resolve(checker).context("invalid data file")?; // basic constraints if self.dim == 0 { @@ -183,18 +168,16 @@ impl CheckDeserialization for DiskIndexBuild { } } -impl CheckDeserialization for DiskSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // inputs +impl DiskSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { self.queries - .check_deserialization(checker) + .resolve(checker) .context("invalid queries file")?; self.groundtruth - .check_deserialization(checker) + .resolve(checker) .context("invalid groundtruth file")?; if let Some(vf) = self.vector_filters_file.as_mut() { - vf.check_deserialization(checker) - .context("invalid vector_filters_file")?; + vf.resolve(checker).context("invalid vector_filters_file")?; } // basic numeric sanity checks diff --git a/diskann-benchmark/src/inputs/exhaustive.rs b/diskann-benchmark/src/inputs/exhaustive.rs index d73bc1491..14fd06336 100644 --- a/diskann-benchmark/src/inputs/exhaustive.rs +++ 
b/diskann-benchmark/src/inputs/exhaustive.rs @@ -6,9 +6,7 @@ use std::num::NonZeroUsize; use anyhow::{anyhow, Context}; -use diskann_benchmark_runner::{ - files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, -}; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; use serde::{Deserialize, Serialize}; use crate::{ @@ -41,8 +39,8 @@ pub(crate) struct SearchValues { pub(crate) recall_n: Vec, } -impl CheckDeserialization for SearchValues { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { +impl SearchValues { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { // Ensure that both `recall_k` and `recall_n` are non-empty. if self.recall_k.is_empty() { return Err(anyhow!("field `recall_k` cannot be empty")); @@ -96,12 +94,11 @@ pub(crate) struct SearchPhase { pub(crate) recalls: SearchValues, } -impl CheckDeserialization for SearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. 
- self.queries.check_deserialization(checker)?; - self.groundtruth.check_deserialization(checker)?; - self.recalls.check_deserialization(checker)?; +impl SearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; + self.recalls.validate(checker)?; Ok(()) } } @@ -219,12 +216,10 @@ impl Product { pub(crate) const fn tag() -> &'static str { "exhaustive-product-quantization" } -} -impl CheckDeserialization for Product { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.data.check_deserialization(checker)?; - self.search.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; // Chcck that provided data type is compatible with `f32`. f32::check_converting_load(self.data_type)?; @@ -368,8 +363,8 @@ impl std::fmt::Display for PreScale { } } -impl CheckDeserialization for PreScale { - fn check_deserialization(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { +impl PreScale { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { if let Self::Some(v) = self { if *v <= 0.0 { anyhow::bail!("pre-scaling {} must be positive", v); @@ -401,12 +396,10 @@ impl Spherical { pub(crate) const fn tag() -> &'static str { "exhaustive-spherical-quantization" } -} -impl CheckDeserialization for Spherical { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.data.check_deserialization(checker)?; - self.search.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; // Chcck that provided data type is compatible with `f32`. 
f32::check_converting_load(self.data_type)?; @@ -422,7 +415,7 @@ impl CheckDeserialization for Spherical { })?; } - self.pre_scale.check_deserialization(checker)?; + self.pre_scale.validate(checker)?; Ok(()) } } @@ -504,12 +497,10 @@ impl MinMax { pub(crate) const fn tag() -> &'static str { "exhaustive-minmax-quantization" } -} -impl CheckDeserialization for MinMax { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.data.check_deserialization(checker)?; - self.search.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.data.resolve(checker)?; + self.search.validate(checker)?; // Chcck that provided data type is compatible with `f32`. f32::check_converting_load(self.data_type)?; diff --git a/diskann-benchmark/src/inputs/filters.rs b/diskann-benchmark/src/inputs/filters.rs index 09fdaf919..942c6da12 100644 --- a/diskann-benchmark/src/inputs/filters.rs +++ b/diskann-benchmark/src/inputs/filters.rs @@ -3,7 +3,7 @@ * Licensed under the MIT license. 
*/ -use diskann_benchmark_runner::{files::InputFile, CheckDeserialization, Checker}; +use diskann_benchmark_runner::{files::InputFile, Checker}; use serde::{Deserialize, Serialize}; use crate::inputs::{as_input, Example}; @@ -55,17 +55,10 @@ impl MetadataIndexBuild { pub(crate) const fn tag() -> &'static str { "metadata-index-build" } -} -impl CheckDeserialization for MetadataIndexBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Validate filter parameters (which include the paths to queries and label files) - self.filter_params - .data_labels - .check_deserialization(checker)?; - self.filter_params - .query_predicates - .check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.filter_params.data_labels.resolve(checker)?; + self.filter_params.query_predicates.resolve(checker)?; Ok(()) } } diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 95cb89484..9135569e8 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -11,9 +11,7 @@ use diskann::{ utils::IntoUsize, }; use diskann_benchmark_core::streaming::executors::bigann; -use diskann_benchmark_runner::{ - files::InputFile, utils::datatype::DataType, CheckDeserialization, Checker, -}; +use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker}; use diskann_providers::{ model::{ configuration::IndexConfiguration, @@ -50,8 +48,8 @@ pub(crate) struct GraphSearch { pub(crate) recall_k: usize, } -impl CheckDeserialization for GraphSearch { - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { +impl GraphSearch { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { for (i, l) in self.search_l.iter().enumerate() { if *l < self.search_n { return Err(anyhow!( @@ -97,9 +95,8 @@ impl 
GraphRangeSearch { } } -impl CheckDeserialization for GraphRangeSearch { - // all necessary checks are carried out when Range is initialized - fn check_deserialization(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { +impl GraphRangeSearch { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> Result<(), anyhow::Error> { self.construct_params() .context("invalid range search params")?; @@ -117,14 +114,12 @@ pub(crate) struct TopkSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for TopkSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - - self.groundtruth.check_deserialization(checker)?; +impl TopkSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -169,14 +164,12 @@ pub(crate) struct RangeSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for RangeSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. 
- self.queries.check_deserialization(checker)?; - - self.groundtruth.check_deserialization(checker)?; +impl RangeSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -197,13 +190,11 @@ pub(crate) struct BetaSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for BetaSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.queries.check_deserialization(checker)?; - - self.query_predicates.check_deserialization(checker)?; - self.data_labels.check_deserialization(checker)?; +impl BetaSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.query_predicates.resolve(checker)?; + self.data_labels.resolve(checker)?; if self.beta <= 0.0 || self.beta > 1.0 { return Err(anyhow::anyhow!( @@ -212,9 +203,9 @@ impl CheckDeserialization for BetaSearchPhase { )); } - self.groundtruth.check_deserialization(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -234,17 +225,14 @@ pub(crate) struct MultiHopSearchPhase { pub(crate) runs: Vec, } -impl CheckDeserialization for MultiHopSearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. 
- self.queries.check_deserialization(checker)?; - - self.query_predicates.check_deserialization(checker)?; - self.data_labels.check_deserialization(checker)?; - - self.groundtruth.check_deserialization(checker)?; +impl MultiHopSearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.queries.resolve(checker)?; + self.query_predicates.resolve(checker)?; + self.data_labels.resolve(checker)?; + self.groundtruth.resolve(checker)?; for (i, run) in self.runs.iter_mut().enumerate() { - run.check_deserialization(checker) + run.validate(checker) .with_context(|| format!("search run {}", i))?; } @@ -394,13 +382,13 @@ impl SearchPhase { } } -impl CheckDeserialization for SearchPhase { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { +impl SearchPhase { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match self { - SearchPhase::Topk(phase) => phase.check_deserialization(checker), - SearchPhase::Range(phase) => phase.check_deserialization(checker), - SearchPhase::TopkBetaFilter(phase) => phase.check_deserialization(checker), - SearchPhase::TopkMultihopFilter(phase) => phase.check_deserialization(checker), + SearchPhase::Topk(phase) => phase.validate(checker), + SearchPhase::Range(phase) => phase.validate(checker), + SearchPhase::TopkBetaFilter(phase) => phase.validate(checker), + SearchPhase::TopkMultihopFilter(phase) => phase.validate(checker), } } } @@ -482,10 +470,8 @@ impl IndexLoad { write_field!(f, "Load Path", self.load_path)?; Ok(()) } -} -impl CheckDeserialization for IndexLoad { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { // Check if the file exists (allowing for relative paths with respect to the current // directory. 
// @@ -652,12 +638,9 @@ impl IndexBuild { } Ok(()) } -} -impl CheckDeserialization for IndexBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. - self.data.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.data.resolve(checker)?; // We allow overwriting of already existing save paths, since users like to do this // The save path must either (1) be an absolute path, in which case we check that its parent directory exists @@ -720,18 +703,14 @@ impl IndexSource { IndexSource::Build(build) => &build.data_type, } } -} -impl CheckDeserialization for IndexSource { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { match self { - IndexSource::Load(load) => load.check_deserialization(checker), - IndexSource::Build(build) => build.check_deserialization(checker), + IndexSource::Load(load) => load.validate(checker), + IndexSource::Build(build) => build.validate(checker), } } -} -impl IndexSource { fn summarize_fields(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { IndexSource::Load(load) => load.summarize_fields(f), @@ -750,13 +729,10 @@ impl IndexOperation { pub(crate) const fn tag() -> &'static str { "graph-index-build" } -} -impl CheckDeserialization for IndexOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Check the validity of the input files. 
- self.source.check_deserialization(checker)?; - self.search_phase.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + self.source.validate(checker)?; + self.search_phase.validate(checker)?; Ok(()) } @@ -832,11 +808,9 @@ impl IndexPQOperation { IndexSource::Build(b) => Ok(b.inmem_parameters(num_points, dim)), } } -} -impl CheckDeserialization for IndexPQOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.index_operation.check_deserialization(checker) + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.index_operation.validate(checker) } } @@ -911,10 +885,8 @@ impl IndexSQOperation { IndexSource::Build(b) => Ok(b.inmem_parameters(num_points, dim)), } } -} -impl CheckDeserialization for IndexSQOperation { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { if self.standard_deviations <= 0.0 { return Err(anyhow::anyhow!( "scalar quantization standard deviations ({}) must be strictly positive", @@ -922,7 +894,7 @@ impl CheckDeserialization for IndexSQOperation { )); } - self.index_operation.check_deserialization(checker) + self.index_operation.validate(checker) } } @@ -994,12 +966,10 @@ impl SphericalQuantBuild { ) -> DefaultProviderParameters { self.build.inmem_parameters(num_points, dim) } -} -impl CheckDeserialization for SphericalQuantBuild { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.build.check_deserialization(checker)?; - self.search_phase.check_deserialization(checker)?; + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.build.validate(checker)?; + self.search_phase.validate(checker)?; if self.build.save_path.is_some() { return Err(anyhow::anyhow!( @@ -1021,7 +991,7 @@ impl CheckDeserialization for 
SphericalQuantBuild { } if let Some(pre_scale) = &mut self.pre_scale { - pre_scale.check_deserialization(checker)?; + pre_scale.validate(checker)?; } Ok(()) @@ -1125,9 +1095,9 @@ pub(crate) struct DynamicRunbookParams { // 1. The runbook file can be parsed // 2. The dataset_name exists in the runbook // 3. All required ground truth files exist in gt_directory -impl CheckDeserialization for DynamicRunbookParams { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.runbook_path.check_deserialization(checker)?; +impl DynamicRunbookParams { + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.runbook_path.resolve(checker)?; // Validate consolidate_threshold is greater than 0 if self.consolidate_threshold <= 0.0 { @@ -1256,6 +1226,12 @@ impl DynamicIndexRun { "graph-index-dynamic-run" } + pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { + self.build.validate(checker)?; + self.runbook_params.validate(checker)?; + self.search_phase.validate(checker)?; + Ok(()) + } pub(crate) fn try_as_config(&self, insert_l: usize) -> anyhow::Result { let mut builder = self.build.try_as_config()?; builder.l_build(insert_l); @@ -1271,15 +1247,6 @@ impl DynamicIndexRun { } } -impl CheckDeserialization for DynamicIndexRun { - fn check_deserialization(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.build.check_deserialization(checker)?; - self.runbook_params.check_deserialization(checker)?; - self.search_phase.check_deserialization(checker)?; - Ok(()) - } -} - impl Example for DynamicIndexRun { fn example() -> Self { let build = IndexBuild::example(); diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index 7875beb1d..89041614e 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -14,29 +14,29 @@ pub(crate) trait Example { fn example() -> Self; } -// NOTE: The input registration and 
dispatching isn't prefect. It uses a pattern (like -// the use of `'static` on the benchmark types) as a byproduct of older ways of doing -// benchmark selection. -// -// In the future, these can be migrated to reduce this legacy cruft. macro_rules! as_input { ($T:ty) => { impl diskann_benchmark_runner::Input for $T { + type Raw = $T; + fn tag() -> &'static str { <$T>::tag() } - fn try_deserialize( - serialized: &serde_json::Value, + fn from_raw( + mut raw: Self::Raw, checker: &mut diskann_benchmark_runner::Checker, - ) -> anyhow::Result { - checker.any(<$T as serde::Deserialize>::deserialize(serialized)?) + ) -> anyhow::Result { + raw.validate(checker)?; + Ok(raw) + } + + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) } - fn example() -> anyhow::Result { - Ok(serde_json::to_value( - <$T as $crate::inputs::Example>::example(), - )?) + fn example() -> Self { + <$T as $crate::inputs::Example>::example() } } }; From 3ac2ae4be9c38ba2e539d408e57039e3a56de72d Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 17:44:51 -0700 Subject: [PATCH 6/9] Last cleanups. --- diskann-benchmark-runner/src/benchmark.rs | 10 +- diskann-benchmark-runner/src/checker.rs | 4 - diskann-benchmark-runner/src/input.rs | 230 +++++++++--------- .../src/internal/regression.rs | 3 +- diskann-benchmark-runner/src/jobs.rs | 9 +- diskann-benchmark-runner/src/registry.rs | 37 +-- diskann-benchmark/README.md | 81 +++--- diskann-benchmark/src/inputs/graph_index.rs | 1 + diskann-benchmark/src/inputs/mod.rs | 4 + 9 files changed, 194 insertions(+), 185 deletions(-) diff --git a/diskann-benchmark-runner/src/benchmark.rs b/diskann-benchmark-runner/src/benchmark.rs index 1130fefaa..dbdfe8063 100644 --- a/diskann-benchmark-runner/src/benchmark.rs +++ b/diskann-benchmark-runner/src/benchmark.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; -use crate::{input::Any, Checkpoint, Input, Output}; +use crate::{Checkpoint, Input, Output}; /// A registered benchmark. 
/// @@ -134,6 +134,8 @@ pub enum PassFail { pub(crate) mod internal { use super::*; + use crate::input::internal::Any; + use anyhow::Context; use thiserror::Error; @@ -179,7 +181,7 @@ pub(crate) mod internal { pub(crate) type CheckedPassFail = PassFail; pub(crate) trait Regression { - fn tolerance(&self) -> &dyn crate::input::DynInput; + fn tolerance(&self) -> &dyn crate::input::internal::DynInput; fn input_tag(&self) -> &'static str; fn check( &self, @@ -228,8 +230,8 @@ pub(crate) mod internal { where T: super::Regression, { - fn tolerance(&self) -> &dyn crate::input::DynInput { - &crate::input::Wrapper::::INSTANCE + fn tolerance(&self) -> &dyn crate::input::internal::DynInput { + &crate::input::internal::Wrapper::::INSTANCE } fn input_tag(&self) -> &'static str { diff --git a/diskann-benchmark-runner/src/checker.rs b/diskann-benchmark-runner/src/checker.rs index da67fb113..03ca6dec0 100644 --- a/diskann-benchmark-runner/src/checker.rs +++ b/diskann-benchmark-runner/src/checker.rs @@ -145,10 +145,6 @@ impl Checker { self.search_directories(), ))) } - - // pub(crate) fn set_tag(&mut self, tag: &'static str) { - // let _ = self.tag.insert(tag); - // } } /////////// diff --git a/diskann-benchmark-runner/src/input.rs b/diskann-benchmark-runner/src/input.rs index 203cd6f1e..59ae6eba4 100644 --- a/diskann-benchmark-runner/src/input.rs +++ b/diskann-benchmark-runner/src/input.rs @@ -5,7 +5,17 @@ use crate::Checker; +/// Inputs to [`Benchmarks`](crate::Benchmark). +/// +/// These begin as [`raw`](Self::Raw) data transfer objects before final construction via +/// [`from_raw`](Self::from_raw). pub trait Input: Sized + std::fmt::Debug + 'static { + /// The raw form of this input that is deserialized from input files and serialized as + /// [`examples`](Self::example). The raw nature of this type reflects that no input + /// validation has been performed beyond the checks performed by its + /// [`Deserialize`](serde::Deserialize) implementation. 
+ /// + /// Final object validation is performed via [`from_raw`](Self::from_raw). type Raw: serde::de::DeserializeOwned + serde::Serialize; /// Return the discriminant associated with this type. @@ -15,36 +25,22 @@ pub trait Input: Sized + std::fmt::Debug + 'static { /// Well formed implementations should always return the same result. fn tag() -> &'static str; - /// Attempt to deserialize an opaque object from the raw `serialized` representation. - /// - /// Deserialized values can be constructed and returned via [`Checker::any`], - /// [`Any::new`] or [`Any::raw`]. - /// - /// If using the [`Any`] constructors directly, implementations should associate - /// [`Self::tag`] with the returned `Any`. If [`Checker::any`] is used - this will - /// happen automatically. - /// - /// Implementations are **strongly** encouraged to implement - /// [`CheckDeserialization`](crate::CheckDeserialization) and use this API to ensure - /// shared resources (like input files or output files) are correctly resolved and - /// properly shared among all jobs in a benchmark run. + /// Construct `Self` from the raw deserialized representation, performing any necessary + /// validation checks (e.g., resolving file paths via the [`Checker`]). fn from_raw(raw: Self::Raw, checker: &mut Checker) -> anyhow::Result; /// Serialize `self` to a [`serde_json::Value`]. fn serialize(&self) -> anyhow::Result; - /// Print an example JSON representation of objects this input is expected to parse. + /// Return an example of a raw input for this [`Input`]. /// - /// Well-formed implementations should ensure that passing the returned - /// [`serde_json::Value`] back to [`Self::try_deserialize`] correctly deserializes, - /// though it need not necessarily pass - /// [`CheckDeserialization`](crate::CheckDeserialization). + /// This is used to supply sample JSON layouts in the benchmark CLI. fn example() -> Self::Raw; } /// A registered input. See [`crate::Registry::input`]. 
#[derive(Clone, Copy)] -pub struct Registered<'a>(pub(crate) &'a dyn DynInput); +pub struct Registered<'a>(pub(crate) &'a dyn internal::DynInput); impl Registered<'_> { /// Return the input tag of the registered input. @@ -56,12 +52,12 @@ impl Registered<'_> { /// Try to deserialize raw JSON into the dynamic type of the input. /// - /// See: [`Input::try_deserialize`]. + /// See: [`Input::from_raw`]. pub(crate) fn try_deserialize( &self, serialized: &serde_json::Value, checker: &mut Checker, - ) -> anyhow::Result { + ) -> anyhow::Result { self.0.try_deserialize(serialized, checker) } @@ -81,123 +77,123 @@ impl std::fmt::Debug for Registered<'_> { } } -////////////// -// Internal // -////////////// +pub(crate) mod internal { + use super::*; -/// Runtime representation of a deserialized [`Input`]. -#[derive(Debug)] -pub(crate) struct Any { - any: Box, -} + /// Runtime representation of a deserialized [`Input`]. + #[derive(Debug)] + pub(crate) struct Any { + any: Box, + } -impl Any { - pub(crate) fn new(input: T) -> Self - where - T: Input + std::fmt::Debug + 'static, - { - Self { - any: Box::new(input), + impl Any { + pub(crate) fn new(input: T) -> Self + where + T: Input + std::fmt::Debug + 'static, + { + Self { + any: Box::new(input), + } } - } - #[must_use = "this function has no side effects"] - pub(crate) fn tag(&self) -> &'static str { - self.any.tag() - } + #[must_use = "this function has no side effects"] + pub(crate) fn tag(&self) -> &'static str { + self.any.tag() + } - #[must_use = "this function has no side effects"] - pub(crate) fn downcast_ref(&self) -> Option<&T> - where - T: std::any::Any, - { - self.any.as_any().downcast_ref::() - } + #[must_use = "this function has no side effects"] + pub(crate) fn downcast_ref(&self) -> Option<&T> + where + T: std::any::Any, + { + self.any.as_any().downcast_ref::() + } - #[must_use = "this function has no side effects"] - pub(crate) fn serialize(&self) -> anyhow::Result { - self.any.serialize() + #[must_use = "this 
function has no side effects"] + pub(crate) fn serialize(&self) -> anyhow::Result { + self.any.serialize() + } } -} - -trait RuntimeAny: std::fmt::Debug { - fn tag(&self) -> &'static str; - fn as_any(&self) -> &dyn std::any::Any; - fn serialize(&self) -> anyhow::Result; -} -impl RuntimeAny for T -where - T: Input + std::fmt::Debug + 'static, -{ - fn tag(&self) -> &'static str { - ::tag() + trait RuntimeAny: std::fmt::Debug { + fn tag(&self) -> &'static str; + fn as_any(&self) -> &dyn std::any::Any; + fn serialize(&self) -> anyhow::Result; } - fn as_any(&self) -> &dyn std::any::Any { - self - } + impl RuntimeAny for T + where + T: Input, + { + fn tag(&self) -> &'static str { + ::tag() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } - fn serialize(&self) -> anyhow::Result { - ::serialize(self) + fn serialize(&self) -> anyhow::Result { + ::serialize(self) + } } -} -#[derive(Debug)] -pub(crate) struct Wrapper(std::marker::PhantomData); + #[derive(Debug)] + pub(crate) struct Wrapper(std::marker::PhantomData); -impl Wrapper { - pub(crate) const INSTANCE: Self = Self::new(); + impl Wrapper { + pub(crate) const INSTANCE: Self = Self::new(); - pub(crate) const fn new() -> Self { - Self(std::marker::PhantomData) + pub(crate) const fn new() -> Self { + Self(std::marker::PhantomData) + } } -} -impl Clone for Wrapper { - fn clone(&self) -> Self { - *self + impl Clone for Wrapper { + fn clone(&self) -> Self { + *self + } } -} -impl Copy for Wrapper {} + impl Copy for Wrapper {} -pub(crate) trait DynInput { - fn tag(&self) -> &'static str; - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result; - fn example(&self) -> anyhow::Result; - - // reflection - fn as_any(&self) -> &dyn std::any::Any; - fn type_name(&self) -> &'static str; -} + pub(crate) trait DynInput { + fn tag(&self) -> &'static str; + fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result; + fn 
example(&self) -> anyhow::Result; -impl DynInput for Wrapper -where - T: Input + std::fmt::Debug + 'static, -{ - fn tag(&self) -> &'static str { - T::tag() - } - fn try_deserialize( - &self, - serialized: &serde_json::Value, - checker: &mut Checker, - ) -> anyhow::Result { - let raw = >::deserialize(serialized)?; - Ok(Any::new(T::from_raw(raw, checker)?)) - } - fn example(&self) -> anyhow::Result { - Ok(serde_json::to_value(T::example())?) - } - fn as_any(&self) -> &dyn std::any::Any { - self + // reflection + fn as_any(&self) -> &dyn std::any::Any; + fn type_name(&self) -> &'static str; } - fn type_name(&self) -> &'static str { - std::any::type_name::() + + impl DynInput for Wrapper + where + T: Input + std::fmt::Debug + 'static, + { + fn tag(&self) -> &'static str { + T::tag() + } + fn try_deserialize( + &self, + serialized: &serde_json::Value, + checker: &mut Checker, + ) -> anyhow::Result { + let raw = >::deserialize(serialized)?; + Ok(Any::new(T::from_raw(raw, checker)?)) + } + fn example(&self) -> anyhow::Result { + Ok(serde_json::to_value(T::example())?) 
+ } + fn as_any(&self) -> &dyn std::any::Any { + self + } + fn type_name(&self) -> &'static str { + std::any::type_name::() + } } } diff --git a/diskann-benchmark-runner/src/internal/regression.rs b/diskann-benchmark-runner/src/internal/regression.rs index eb7289cc7..2df2d922a 100644 --- a/diskann-benchmark-runner/src/internal/regression.rs +++ b/diskann-benchmark-runner/src/internal/regression.rs @@ -100,7 +100,7 @@ use serde_json::Value; use crate::{ benchmark::{internal::CheckedPassFail, PassFail}, - input::Any, + input::internal::Any, internal::load_from_disk, jobs, registry, result, Checker, }; @@ -349,7 +349,6 @@ impl Raw { .with_context(context); } - // checker.set_tag(entry.tolerance.tag()); let tolerance = entry .tolerance .try_deserialize(&unprocessed.tolerance.content, &mut checker) diff --git a/diskann-benchmark-runner/src/jobs.rs b/diskann-benchmark-runner/src/jobs.rs index 9c43acbef..0cc75e751 100644 --- a/diskann-benchmark-runner/src/jobs.rs +++ b/diskann-benchmark-runner/src/jobs.rs @@ -13,17 +13,17 @@ use crate::{checker::Checker, input, Registry}; #[derive(Debug)] pub(crate) struct Jobs { /// The benchmark jobs to execute. - jobs: Vec, + jobs: Vec, } impl Jobs { /// Return the jobs associated with this benchmark run. - pub(crate) fn jobs(&self) -> &[input::Any] { + pub(crate) fn jobs(&self) -> &[input::internal::Any] { &self.jobs } /// Consume `self`, returning the contained list of jobs. 
- pub(crate) fn into_inner(self) -> Vec { + pub(crate) fn into_inner(self) -> Vec { self.jobs } @@ -51,7 +51,7 @@ impl Jobs { ); let num_jobs = partial.jobs.len(); - let jobs: anyhow::Result> = partial + let jobs: anyhow::Result> = partial .jobs .iter() .enumerate() @@ -71,7 +71,6 @@ impl Jobs { }) .with_context(context)?; - // checker.set_tag(input.tag()); input .try_deserialize(&unprocessed.content, &mut checker) .with_context(context) diff --git a/diskann-benchmark-runner/src/registry.rs b/diskann-benchmark-runner/src/registry.rs index 91a82ed34..4cc4aaa1b 100644 --- a/diskann-benchmark-runner/src/registry.rs +++ b/diskann-benchmark-runner/src/registry.rs @@ -41,7 +41,7 @@ impl RegisteredBenchmark { /// A collection of registered inputs and benchmarks. pub struct Registry { // Inputs keyed by their tag type. - inputs: HashMap<&'static str, Box>, + inputs: HashMap<&'static str, Box>, benchmarks: Vec, } @@ -105,7 +105,7 @@ impl Registry { } /// Return `true` if `job` matches with any registered benchmark. Otherwise, return `false`. - pub(crate) fn has_match(&self, job: &input::Any) -> bool { + pub(crate) fn has_match(&self, job: &input::internal::Any) -> bool { self.find_best_match(job).is_some() } @@ -116,7 +116,7 @@ impl Registry { /// Errors if a suitable method could not be found or if the invoked benchmark failed. pub(crate) fn call( &self, - job: &input::Any, + job: &input::internal::Any, checkpoint: Checkpoint<'_>, output: &mut dyn Output, ) -> anyhow::Result { @@ -132,7 +132,11 @@ impl Registry { /// reasons. /// /// Returns `Ok(())` if a match was found. - pub(crate) fn debug(&self, job: &input::Any, max_methods: usize) -> Result<(), Vec> { + pub(crate) fn debug( + &self, + job: &input::internal::Any, + max_methods: usize, + ) -> Result<(), Vec> { if self.has_match(job) { return Ok(()); } @@ -166,7 +170,7 @@ impl Registry { } /// Find the best matching benchmark for `job` by score. 
- fn find_best_match(&self, job: &input::Any) -> Option<&RegisteredBenchmark> { + fn find_best_match(&self, job: &input::internal::Any) -> Option<&RegisteredBenchmark> { self.benchmarks .iter() .filter_map(|entry| { @@ -180,7 +184,7 @@ impl Registry { .map(|(entry, _)| entry) } - fn _input(&self, tag: &str) -> Option<&dyn input::DynInput> { + fn _input(&self, tag: &str) -> Option<&dyn input::internal::DynInput> { self.inputs.get(tag).map(|v| &**v) } @@ -189,16 +193,16 @@ impl Registry { T: Input + 'static, { let tag = T::tag(); - let wrapper = crate::input::Wrapper::::new(); + let wrapper = crate::input::internal::Wrapper::::new(); match self.inputs.entry(tag) { Entry::Vacant(v) => { v.insert(Box::new(wrapper)); Ok(()) } Entry::Occupied(o) => { - use input::DynInput; + use input::internal::DynInput; - if o.get().as_any().is::>() { + if o.get().as_any().is::>() { Ok(()) } else { Err(RegistryError { @@ -334,14 +338,17 @@ impl RegressionBenchmark<'_> { self.regression.input_tag() } - pub(crate) fn try_match(&self, input: &input::Any) -> Result { + pub(crate) fn try_match( + &self, + input: &input::internal::Any, + ) -> Result { self.benchmark.benchmark().try_match(input) } pub(crate) fn check( &self, - tolerance: &input::Any, - input: &input::Any, + tolerance: &input::internal::Any, + input: &input::internal::Any, before: &serde_json::Value, after: &serde_json::Value, ) -> anyhow::Result { @@ -362,7 +369,7 @@ pub(crate) struct RegisteredTolerance<'a> { /// Helper to capture a `Benchmark::description` call into a `String` via `Display`. 
struct Capture<'a>( &'a dyn benchmark::internal::Benchmark, - Option<&'a input::Any>, + Option<&'a input::internal::Any>, ); impl std::fmt::Display for Capture<'_> { @@ -427,7 +434,7 @@ mod tests { { let a = registry._input(A::tag()).unwrap(); - assert!(a.as_any().is::>()); + assert!(a.as_any().is::>()); let name = a.type_name(); assert!(name.contains("A"), "{}", name); @@ -435,7 +442,7 @@ mod tests { { let b = registry._input(B::tag()).unwrap(); - assert!(b.as_any().is::>()); + assert!(b.as_any().is::>()); let name = b.type_name(); assert!(name.contains("B"), "{}", name); diff --git a/diskann-benchmark/README.md b/diskann-benchmark/README.md index 923bb9e27..b48fdf18a 100644 --- a/diskann-benchmark/README.md +++ b/diskann-benchmark/README.md @@ -206,20 +206,20 @@ this is usually easily done with a small code change. With the example of adding Range search to the `f16` index, the registration site: ```rust -benchmarks.register( +registry.register( "async-full-precision-f16", FullPrecision::::new() .search(plugins::Topk), -); +)?; ``` Can be updated to: ```rust -benchmarks.register( +registry.register( "async-full-precision-f16", FullPrecision::::new() .search(plugins::Topk) .search(plugins::Range), -); +)?; ``` This will both compile the range search implementation and make it available for benchmark matching. @@ -337,53 +337,49 @@ pub(crate) struct ComputeGroundTruth { pub(crate) num_nearest_neighbors: usize, } ``` -We need to implement a few traits related to this input type: +We need to implement `diskann_benchmark_runner::Input` for the type. This trait associates +a tag name used for deserialization and benchmark matching, a `Raw` type for JSON +serialization/deserialization, a `from_raw` constructor that performs post-deserialization +validation (e.g., resolving file paths via the `Checker`), and an `example` that supplies +sample JSON layouts for the CLI. 
-* `diskann_benchmark_runner::Input`: A type-name for this input that is used to identify it for - deserialization and benchmark matching. To make this easier, `benchmark` defines - `benchmark::inputs::Input` that can be used to express type level implementation (shown - below) - -* `CheckDeserialization`: This trait performs post-deserialization invariant checking. - In the context of the `ComputeGroundTruth` type, we use this to check that the input - files are valid. +In the context of the `ComputeGroundTruth` type, we use `from_raw` to check that the input +files are valid. ```rust -impl diskann_benchmark_runner::Input for crate::inputs::Input { +impl diskann_benchmark_runner::Input for ComputeGroundTruth { + // The raw form is just `Self` since the struct is directly deserializable. + type Raw = Self; + // This gets associated with the JSON representation returned by `example` and at run - // time, inputs tagged with this value will be given to `try_deserialize`. + // time, inputs tagged with this value will be given to `from_raw`. fn tag() -> &'static str { "compute_groundtruth" } - // Attempt to deserialize `Self` from raw JSON. - // - // Implementors can assume that `serialized` looks similar in structure to what is - // returned by `example`. - fn try_deserialize( - serialized: &serde_json::Value, + // Construct from the raw deserialized form, performing file path resolution. + fn from_raw( + mut raw: Self::Raw, checker: &mut diskann_benchmark_runner::Checker, - ) -> anyhow::Result { - checker.any(ComputeGroundTruth::deserialize(serialized)?) + ) -> anyhow::Result { + raw.data.resolve(checker)?; + raw.queries.resolve(checker)?; + Ok(raw) } - // Return a serialized representation of `self` to help users create an input file. - fn example() -> anyhow::Result { - serde_json::to_value(Self { + // Serialize `self` to JSON. + fn serialize(&self) -> anyhow::Result { + Ok(serde_json::to_value(self)?) 
+ } + + // Return an example input to help users create an input file. + fn example() -> Self { + Self { data_type: DataType::Float32, data: InputFile::new("path/to/data"), queries: InputFile::new("path/to/queries"), num_nearest_neighbors: 100, - }) - } -} - -impl CheckDeserialization for ComputeGroundTruth { - fn check_deserialization(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { - // Forward the deserialization check to the input files. - self.data.check_deserialization(checker)?; - self.queries.check_deserialization(checker)?; - Ok(()) + } } } ``` @@ -409,8 +405,8 @@ To implement benchmarks, we register them with the `diskann_benchmark_runner::Re The simplest thing we can do is something like this: ```rust use diskann_benchmark_runner::{ - dispatcher::{MatchScore, FailureScore}, - Any, Benchmark, Checkpoint, Output, Registry, + benchmark::{MatchScore, FailureScore}, + Benchmark, Checkpoint, Output, }; // Benchmarks can be stateful. @@ -429,6 +425,15 @@ impl Benchmark for RunGroundTruth { Ok(MatchScore::new(0)) } + // Describe the benchmark for CLI display and debugging. + fn description( + &self, + f: &mut std::fmt::Formatter<'_>, + _input: Option<&Self::Input>, + ) -> std::fmt::Result { + write!(f, "compute groundtruth") + } + // Run the benchmark (for this example, nothing happens). 
fn run( &self, diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 9135569e8..9df194382 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -1232,6 +1232,7 @@ impl DynamicIndexRun { self.search_phase.validate(checker)?; Ok(()) } + pub(crate) fn try_as_config(&self, insert_l: usize) -> anyhow::Result { let mut builder = self.build.try_as_config()?; builder.l_build(insert_l); diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index 89041614e..8c9c58dda 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -14,6 +14,10 @@ pub(crate) trait Example { fn example() -> Self; } +/// Implement [`diskann_benchmark_runner::Input`] for `$T` using `Raw = $T`. +/// +/// Requires `$T` to implement [`Example`] and provide a +/// `fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()>` method. macro_rules! as_input { ($T:ty) => { impl diskann_benchmark_runner::Input for $T { From 331a6aca5e085765249b30091f36ec8bf09df3b3 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 17:56:48 -0700 Subject: [PATCH 7/9] Don't change UX tests. 
--- diskann-benchmark-runner/src/test/typed.rs | 3 ++- .../tests/benchmark/test-mismatch-0/stdout.txt | 3 ++- .../tests/benchmark/test-mismatch-1/stdout.txt | 3 ++- .../tests/benchmark/test-overload-0/output.json | 6 ++++-- .../tests/benchmark/test-success-0/output.json | 6 ++++-- .../tests/regression/check-run-error-0/output.json | 3 ++- .../tests/regression/check-run-error-2/output.json | 3 ++- .../tests/regression/check-run-fail-0/output.json | 3 ++- .../tests/regression/check-run-pass-0/output.json | 6 ++++-- 9 files changed, 24 insertions(+), 12 deletions(-) diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index f3c0c3885..0f5a54201 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -21,6 +21,7 @@ use crate::{ pub(crate) struct TypeInput { pub(super) data_type: DataType, pub(super) dim: usize, + error_when_checked: bool, } #[derive(Serialize, Deserialize)] @@ -33,7 +34,7 @@ pub(crate) struct TypeInputRaw { impl TypeInput { pub(crate) fn new(data_type: DataType, dim: usize) -> Self { - Self { data_type, dim } + Self { data_type, dim, error_when_checked: false } } fn run(&self) -> &'static str { diff --git a/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt index 7e26fb341..ba72e9bbf 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt +++ b/diskann-benchmark-runner/tests/benchmark/test-mismatch-0/stdout.txt @@ -2,7 +2,8 @@ Could not find a match for the following input: { "data_type": "float16", - "dim": 128 + "dim": 128, + "error_when_checked": false } Closest matches: diff --git a/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt b/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt index 85da0f389..34be87554 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt +++ 
b/diskann-benchmark-runner/tests/benchmark/test-mismatch-1/stdout.txt @@ -2,7 +2,8 @@ Could not find a match for the following input: { "data_type": "float16", - "dim": 1000 + "dim": 1000, + "error_when_checked": false } Closest matches: diff --git a/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json b/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json index d53fdfdd2..8fdbaa7e4 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json +++ b/diskann-benchmark-runner/tests/benchmark/test-overload-0/output.json @@ -3,7 +3,8 @@ "input": { "content": { "data_type": "float32", - "dim": 1000 + "dim": 1000, + "error_when_checked": false }, "type": "test-input-types" }, @@ -13,7 +14,8 @@ "input": { "content": { "data_type": "float32", - "dim": 128 + "dim": 128, + "error_when_checked": false }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json b/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json index 398b4f358..5b15f5ac2 100644 --- a/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json +++ b/diskann-benchmark-runner/tests/benchmark/test-success-0/output.json @@ -21,7 +21,8 @@ "input": { "content": { "data_type": "float32", - "dim": 128 + "dim": 128, + "error_when_checked": false }, "type": "test-input-types" }, @@ -31,7 +32,8 @@ "input": { "content": { "data_type": "int8", - "dim": 128 + "dim": 128, + "error_when_checked": false }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json b/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json index f79fc4729..9bad43329 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-error-0/output.json @@ -12,7 +12,8 @@ "input": { "content": { "data_type": "int8", - "dim": 128 + "dim": 128, + "error_when_checked": false 
}, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json b/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json index 4f93315f1..869ae3cc8 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-error-2/output.json @@ -3,7 +3,8 @@ "input": { "content": { "data_type": "float32", - "dim": 128 + "dim": 128, + "error_when_checked": false }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json b/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json index 77c4fa441..5112717a9 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-fail-0/output.json @@ -12,7 +12,8 @@ "input": { "content": { "data_type": "int8", - "dim": 128 + "dim": 128, + "error_when_checked": false }, "type": "test-input-types" }, diff --git a/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json b/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json index 540e94d1b..5386e9dee 100644 --- a/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json +++ b/diskann-benchmark-runner/tests/regression/check-run-pass-0/output.json @@ -12,7 +12,8 @@ "input": { "content": { "data_type": "float32", - "dim": 1000 + "dim": 1000, + "error_when_checked": false }, "type": "test-input-types" }, @@ -22,7 +23,8 @@ "input": { "content": { "data_type": "int8", - "dim": 128 + "dim": 128, + "error_when_checked": false }, "type": "test-input-types" }, From 100b35c5fdacf3de873e89a07869beaf0d590d9d Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 18:04:28 -0700 Subject: [PATCH 8/9] Run fmt. 
--- diskann-benchmark-runner/src/test/typed.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/diskann-benchmark-runner/src/test/typed.rs b/diskann-benchmark-runner/src/test/typed.rs index 0f5a54201..f737f875c 100644 --- a/diskann-benchmark-runner/src/test/typed.rs +++ b/diskann-benchmark-runner/src/test/typed.rs @@ -34,7 +34,11 @@ pub(crate) struct TypeInputRaw { impl TypeInput { pub(crate) fn new(data_type: DataType, dim: usize) -> Self { - Self { data_type, dim, error_when_checked: false } + Self { + data_type, + dim, + error_when_checked: false, + } } fn run(&self) -> &'static str { From f1011321f08bd562fce88ce27c46924b9dd47816 Mon Sep 17 00:00:00 2001 From: Mark Hildebrand Date: Fri, 15 May 2026 18:17:41 -0700 Subject: [PATCH 9/9] Apply suggestions from code review Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark/src/inputs/exhaustive.rs | 2 +- diskann-benchmark/src/inputs/mod.rs | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/diskann-benchmark/src/inputs/exhaustive.rs b/diskann-benchmark/src/inputs/exhaustive.rs index 14fd06336..20583de85 100644 --- a/diskann-benchmark/src/inputs/exhaustive.rs +++ b/diskann-benchmark/src/inputs/exhaustive.rs @@ -221,7 +221,7 @@ impl Product { self.data.resolve(checker)?; self.search.validate(checker)?; - // Chcck that provided data type is compatible with `f32`. + // Check that provided data type is compatible with `f32`. f32::check_converting_load(self.data_type)?; let num_centers = self.num_pq_centers.get(); diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index 8c9c58dda..492f0b9c1 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -16,8 +16,13 @@ pub(crate) trait Example { /// Implement [`diskann_benchmark_runner::Input`] for `$T` using `Raw = $T`. 
/// -/// Requires `$T` to implement [`Example`] and provide a -/// `fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()>` method. +/// Requires `$T` to: +/// - implement [`Example`]; +/// - provide an inherent `fn tag() -> &'static str` method; +/// - provide a +/// `fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()>` method; and +/// - implement the serde traits required by +/// [`diskann_benchmark_runner::Input`] and `serde_json::to_value(self)`. macro_rules! as_input { ($T:ty) => { impl diskann_benchmark_runner::Input for $T {