diff --git a/Cargo.toml b/Cargo.toml index e7d2fab..8083b38 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ] pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ] pg16 = ["pgrx/pg16", "pgrx-tests/pg16" ] pg17 = ["pgrx/pg17", "pgrx-tests/pg17" ] -pg_test = ["tempfile"] +pg_test = ["tempfile", "sqllogictest"] [dependencies] pgrx = "=0.14.3" @@ -34,10 +34,12 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" chrono = "0.4" base64 = "0.22" +sqllogictest = { version = "=0.29.0", optional = true } [dev-dependencies] pgrx-tests = "=0.14.3" tempfile = "3.8" +sqllogictest = "=0.29.0" [dependencies.tempfile] version = "3.8" diff --git a/src/lib.rs b/src/lib.rs index e855523..5920eba 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,124 +48,7 @@ fn lance_import( } #[cfg(any(test, feature = "pg_test"))] -#[pg_schema] -mod tests { - use arrow::array::{ - Array, BooleanArray, Float32Array, Int32Array, ListBuilder, StringArray, StructArray, - }; - use arrow::datatypes::{DataType, Field, Schema}; - use arrow::record_batch::RecordBatch; - use lance_rs::Dataset; - use pgrx::prelude::*; - use std::sync::Arc; - use tempfile::TempDir; - - struct LanceTestDataGenerator { - temp_dir: TempDir, - } - - impl LanceTestDataGenerator { - fn new() -> Result> { - Ok(Self { - temp_dir: TempDir::new()?, - }) - } - - fn create_table_with_struct_and_list( - &self, - ) -> Result> { - let table_path = self.temp_dir.path().join("fdw_table"); - - let id_array = Int32Array::from(vec![1, 2, 3]); - let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie"]); - let active_array = BooleanArray::from(vec![true, false, true]); - - let mut emb_builder = ListBuilder::new(arrow::array::Float32Builder::new()); - for embedding in [ - vec![0.1, 0.2, 0.3], - vec![0.4, 0.5, 0.6], - vec![0.7, 0.8, 0.9], - ] { - for v in embedding { - emb_builder.values().append_value(v); - } - emb_builder.append(true); - } - let emb_array = emb_builder.finish(); - - let meta_score = Float32Array::from(vec![1.0, 2.0, 3.0]); - let meta_tag = StringArray::from(vec!["a", "b", "c"]); - let meta_struct = StructArray::from(vec![ - ( - Arc::new(Field::new("score", DataType::Float32, false)), - Arc::new(meta_score) as _, - ), - ( - Arc::new(Field::new("tag", DataType::Utf8, false)), - Arc::new(meta_tag) as _, - ), - ]); - - let schema = Arc::new(Schema::new(vec![ - Field::new("id", DataType::Int32, false), - Field::new("name", DataType::Utf8, false), - Field::new("active", DataType::Boolean, false), - Field::new( - "embedding", - DataType::List(Arc::new(Field::new("item", DataType::Float32, true))), - false, - ), - Field::new("meta", meta_struct.data_type().clone(), false), - ])); - - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(id_array), - Arc::new(name_array), - Arc::new(active_array), - Arc::new(emb_array), - Arc::new(meta_struct), - ], - )?; - - let reader = arrow::record_batch::RecordBatchIterator::new(vec![Ok(batch)], schema); - let rt = tokio::runtime::Runtime::new().unwrap(); - rt.block_on(async { - Dataset::write(reader, table_path.to_str().unwrap(), None).await - })?; - - Ok(table_path) - } - } - - #[pg_test] - fn test_fdw_import_and_scan() { - let gen = LanceTestDataGenerator::new().expect("generator"); - let path = gen - .create_table_with_struct_and_list() - .expect("create table"); - let uri = path.to_str().unwrap(); - - Spi::run("CREATE SERVER lance_srv FOREIGN DATA WRAPPER lance_fdw").expect("create server"); - - let import_sql = format!( - "SELECT lance_import('lance_srv', 'public', 't_fdw', '{}', NULL)", - uri.replace('\'', "''") - ); - Spi::run(&import_sql).expect("lance_import"); - - let cnt = Spi::get_one::("SELECT count(*) FROM public.t_fdw") - .expect("count") - .expect("count value"); - assert_eq!(cnt, 3); - - let v = Spi::get_one::("SELECT name FROM public.t_fdw WHERE id = 2") - .expect("select") - .expect("value"); - assert_eq!(v, "Bob"); - } -} +mod tests; #[cfg(test)] pub mod pg_test { diff --git a/src/tests.rs b/src/tests.rs new file mode 100644 index 0000000..cd457cf --- /dev/null +++ b/src/tests.rs @@ -0,0 +1,305 @@ +use pgrx::prelude::*; + +#[pg_schema] +mod tests { + use super::*; + use arrow::array::{ + Array, BooleanArray, Float32Array, Int32Array, ListBuilder, StringArray, StructArray, + }; + use arrow::datatypes::{DataType, Field, Schema}; + use arrow::record_batch::RecordBatch; + use lance_rs::Dataset; + use sqllogictest::{DBOutput, DefaultColumnType, Runner}; + use std::ffi::{CStr, OsStr}; + use std::fs; + use std::path::{Path, PathBuf}; + use std::sync::Arc; + use tempfile::TempDir; + + struct SpiSltDb; + + impl sqllogictest::DB for SpiSltDb { + type Error = pgrx::spi::SpiError; + type ColumnType = DefaultColumnType; + + fn run(&mut self, sql: &str) -> Result, Self::Error> { + let sql = sql.trim(); + if sql.is_empty() { + return Ok(DBOutput::StatementComplete(0)); + } + + Spi::connect_mut(|client| { + let mut tuptable = client.update(sql, None, &[])?; + + let columns = match tuptable.columns() { + Ok(columns) => columns, + Err(pgrx::spi::SpiError::NoTupleTable) => { + return Ok(DBOutput::StatementComplete(tuptable.len() as u64)); + } + Err(e) => return Err(e), + }; + + let mut types = Vec::with_capacity(columns); + let mut type_oids = Vec::with_capacity(columns); + for i in 1..=columns { + let oid = tuptable.column_type_oid(i)?.value(); + types.push(map_pg_oid_to_slt(oid)); + type_oids.push(oid); + } + + let mut rows = Vec::new(); + while tuptable.next().is_some() { + let mut row = Vec::with_capacity(columns); + for (idx, oid) in type_oids.iter().enumerate() { + let datum = tuptable.get_datum_by_ordinal(idx + 1)?; + row.push(format_pg_datum(datum, *oid)); + } + rows.push(row); + } + + Ok(DBOutput::Rows { types, rows }) + }) + } + + fn engine_name(&self) -> &str { + "postgres" + } + } + + fn map_pg_oid_to_slt(oid: pg_sys::Oid) -> DefaultColumnType { + match PgOid::from_untagged(oid) { + PgOid::BuiltIn(builtin) => match builtin { + pg_sys::BuiltinOid::INT2OID + | pg_sys::BuiltinOid::INT4OID + | pg_sys::BuiltinOid::INT8OID + | pg_sys::BuiltinOid::OIDOID => DefaultColumnType::Integer, + pg_sys::BuiltinOid::FLOAT4OID + | pg_sys::BuiltinOid::FLOAT8OID + | pg_sys::BuiltinOid::NUMERICOID => DefaultColumnType::FloatingPoint, + _ => DefaultColumnType::Text, + }, + _ => DefaultColumnType::Text, + } + } + + fn format_pg_datum(datum: Option, type_oid: pg_sys::Oid) -> String { + match datum { + None => "NULL".to_string(), + Some(datum) => unsafe { + let mut out_func = pg_sys::Oid::from(0u32); + let mut is_varlena = false; + pg_sys::getTypeOutputInfo(type_oid, &mut out_func, &mut is_varlena); + + let ptr = pg_sys::OidOutputFunctionCall(out_func, datum); + let s = CStr::from_ptr(ptr) + .to_str() + .unwrap_or("") + .to_string(); + pg_sys::pfree(ptr as *mut _); + s + }, + } + } + + fn slt_identifier(input: &str) -> String { + let mut out = String::with_capacity(input.len()); + for ch in input.chars() { + if ch.is_ascii_alphanumeric() { + out.push(ch.to_ascii_lowercase()); + } else { + out.push('_'); + } + } + if out.is_empty() { + out.push('_'); + } + let first = out.as_bytes()[0]; + if !first.is_ascii_alphabetic() && first != b'_' { + out.insert(0, '_'); + } + if out.len() > 50 { + out.truncate(50); + } + out + } + + fn list_slt_files(dir: &Path) -> Vec { + let mut files: Vec = fs::read_dir(dir) + .expect("read_dir tests/sql") + .filter_map(|entry| entry.ok()) + .map(|entry| entry.path()) + .filter(|path| path.extension().is_some_and(|ext| ext == OsStr::new("slt"))) + .collect(); + files.sort(); + files + } + + struct LanceTestDataGenerator { + temp_dir: TempDir, + } + + impl LanceTestDataGenerator { + fn new() -> Result> { + Ok(Self { + temp_dir: TempDir::new()?, + }) + } + + fn create_table_with_struct_and_list( + &self, + ) -> Result> { + let table_path = self.temp_dir.path().join("fdw_table"); + + let id_array = Int32Array::from(vec![1, 2, 3]); + let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie"]); + let active_array = BooleanArray::from(vec![true, false, true]); + + let mut emb_builder = ListBuilder::new(arrow::array::Float32Builder::new()); + for embedding in [ + vec![0.1, 0.2, 0.3], + vec![0.4, 0.5, 0.6], + vec![0.7, 0.8, 0.9], + ] { + for v in embedding { + emb_builder.values().append_value(v); + } + emb_builder.append(true); + } + let emb_array = emb_builder.finish(); + + let meta_score = Float32Array::from(vec![1.0, 2.0, 3.0]); + let meta_tag = StringArray::from(vec!["a", "b", "c"]); + let meta_struct = StructArray::from(vec![ + ( + Arc::new(Field::new("score", DataType::Float32, false)), + Arc::new(meta_score) as _, + ), + ( + Arc::new(Field::new("tag", DataType::Utf8, false)), + Arc::new(meta_tag) as _, + ), + ]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, false), + Field::new("active", DataType::Boolean, false), + Field::new( + "embedding", + DataType::List(Arc::new(Field::new("item", DataType::Float32, true))), + false, + ), + Field::new("meta", meta_struct.data_type().clone(), false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(id_array), + Arc::new(name_array), + Arc::new(active_array), + Arc::new(emb_array), + Arc::new(meta_struct), + ], + )?; + + let reader = arrow::record_batch::RecordBatchIterator::new(vec![Ok(batch)], schema); + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + Dataset::write(reader, table_path.to_str().unwrap(), None).await + })?; + + Ok(table_path) + } + } + + #[pg_test] + fn test_fdw_import_and_scan() { + Spi::run("SELECT pg_advisory_lock(424242)").expect("advisory lock"); + + let gen = LanceTestDataGenerator::new().expect("generator"); + let path = gen + .create_table_with_struct_and_list() + .expect("create table"); + let uri = path.to_str().unwrap(); + + Spi::run("DROP SCHEMA IF EXISTS slt_unit CASCADE").expect("drop schema"); + Spi::run("CREATE SCHEMA slt_unit").expect("create schema"); + Spi::run("SET search_path TO slt_unit, public").expect("set search_path"); + Spi::run("DROP SERVER IF EXISTS lance_srv_unit CASCADE").expect("drop server"); + Spi::run("CREATE SERVER lance_srv_unit FOREIGN DATA WRAPPER lance_fdw") + .expect("create server"); + + let import_sql = format!( + "SELECT lance_import('lance_srv_unit', 'slt_unit', 't_fdw', '{}', NULL)", + uri.replace('\'', "''") + ); + Spi::run(&import_sql).expect("lance_import"); + + let cnt = Spi::get_one::("SELECT count(*) FROM slt_unit.t_fdw") + .expect("count") + .expect("count value"); + assert_eq!(cnt, 3); + + let v = Spi::get_one::("SELECT name FROM slt_unit.t_fdw WHERE id = 2") + .expect("select") + .expect("value"); + assert_eq!(v, "Bob"); + + Spi::run("SELECT pg_advisory_unlock(424242)").expect("advisory unlock"); + } + + #[pg_test] + fn test_sqllogictest() { + Spi::run("SELECT pg_advisory_lock(424242)").expect("advisory lock"); + + let gen = LanceTestDataGenerator::new().expect("generator"); + let path = gen + .create_table_with_struct_and_list() + .expect("create table"); + let uri = path.to_str().expect("uri").replace('\'', "''"); + + let scripts_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/sql"); + let slt_files = list_slt_files(&scripts_dir); + assert!( + !slt_files.is_empty(), + "no .slt files found under {}", + scripts_dir.display() + ); + + for (idx, file) in slt_files.iter().enumerate() { + let stem = file + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("unknown"); + let schema = format!("slt_{}_{}", idx, slt_identifier(stem)); + let server = format!("{}_srv", schema); + + let mut script = fs::read_to_string(file).expect("read .slt file"); + script = script.replace("${LANCE_URI}", &uri); + script = script.replace("${SCHEMA}", &schema); + script = script.replace("${SERVER}", &server); + + let prefix = format!( + "statement ok\n\ +DROP SCHEMA IF EXISTS {schema} CASCADE;\n\n\ +statement ok\n\ +CREATE SCHEMA {schema};\n\n\ +statement ok\n\ +SET search_path TO {schema}, public;\n\n\ +statement ok\n\ +DROP SERVER IF EXISTS {server} CASCADE;\n\n\ +statement ok\n\ +CREATE SERVER {server} FOREIGN DATA WRAPPER lance_fdw;\n\n" + ); + let full_script = format!("{prefix}\n{script}\n"); + + let mut runner = Runner::new(|| async { Ok::<_, pgrx::spi::SpiError>(SpiSltDb) }); + if let Err(e) = runner.run_script_with_name(&full_script, file.display().to_string()) { + panic!("{}", e.display(false)); + } + } + + Spi::run("SELECT pg_advisory_unlock(424242)").expect("advisory unlock"); + } +} diff --git a/tests/sql/00_smoke.slt b/tests/sql/00_smoke.slt new file mode 100644 index 0000000..72d76f7 --- /dev/null +++ b/tests/sql/00_smoke.slt @@ -0,0 +1,21 @@ +# Basic smoke test for lance FDW. + +statement ok +SELECT lance_import('${SERVER}', '${SCHEMA}', 't_fdw', '${LANCE_URI}', NULL); + +query I +SELECT count(*) FROM t_fdw; +---- +3 + +query T +SELECT name FROM t_fdw WHERE id = 2; +---- +Bob + +query I +SELECT array_length(embedding, 1) FROM t_fdw ORDER BY id; +---- +3 +3 +3