diff --git a/datafusion-examples/examples/custom_data_source/custom_datasource.rs b/datafusion-examples/examples/custom_data_source/custom_datasource.rs index 937452a286b90..a67738520b010 100644 --- a/datafusion-examples/examples/custom_data_source/custom_datasource.rs +++ b/datafusion-examples/examples/custom_data_source/custom_datasource.rs @@ -26,6 +26,7 @@ use async_trait::async_trait; use datafusion::arrow::array::{UInt8Builder, UInt64Builder}; use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; +use datafusion::common::assert_batches_eq; use datafusion::datasource::{TableProvider, TableType, provider_as_source}; use datafusion::error::Result; use datafusion::execution::context::TaskContext; @@ -52,6 +53,33 @@ pub async fn custom_datasource() -> Result<()> { search_accounts(db.clone(), Some(col("bank_account").gt(lit(8000u64))), 1).await?; search_accounts(db.clone(), Some(col("bank_account").gt(lit(200u64))), 2).await?; + // exercise SQL paths that push down non-trivial projections: + // - `SELECT 1 ...` requests no source columns (projection: Some([])) + // - `SELECT COUNT(id) ...` requests a single column (projection: Some([0])) + let ctx = SessionContext::new(); + ctx.register_table("accounts", Arc::new(db))?; + let constant_batches = ctx + .sql("SELECT 1 AS a FROM accounts") + .await? + .collect() + .await?; + assert_batches_eq!( + [ + "+---+", "| a |", "+---+", "| 1 |", "| 1 |", "| 1 |", "+---+", + ], + &constant_batches + ); + + let count_batches = ctx + .sql("SELECT COUNT(id) AS cnt FROM accounts") + .await? + .collect() + .await?; + assert_batches_eq!( + ["+-----+", "| cnt |", "+-----+", "| 3 |", "+-----+",], + &count_batches + ); + Ok(()) } @@ -186,6 +214,7 @@ impl TableProvider for CustomDataSource { #[derive(Debug, Clone)] struct CustomExec { db: CustomDataSource, + projection: Option>, projected_schema: SchemaRef, cache: Arc, } @@ -201,6 +230,7 @@ impl CustomExec { let cache = Self::compute_properties(projected_schema.clone()); Self { db, + projection: projections.cloned(), projected_schema, cache: Arc::new(cache), } @@ -262,15 +292,25 @@ impl ExecutionPlan for CustomExec { account_array.append_value(user.bank_account); } + // Build a batch holding every column the table can produce, then let + // Arrow drop the columns the query didn't ask for. `RecordBatch::project` + // preserves the row count, which matters when the projection selects + // zero columns (e.g. `SELECT 1 FROM t`). + let full_batch = RecordBatch::try_new( + self.db.schema(), + vec![ + Arc::new(id_array.finish()), + Arc::new(account_array.finish()), + ], + )?; + let batch = match &self.projection { + Some(indices) => full_batch.project(indices)?, + None => full_batch, + }; + Ok(Box::pin(MemoryStream::try_new( - vec![RecordBatch::try_new( - self.projected_schema.clone(), - vec![ - Arc::new(id_array.finish()), - Arc::new(account_array.finish()), - ], - )?], - self.schema(), + vec![batch], + self.projected_schema.clone(), None, )?)) }