diff --git a/src/fdw/convert.rs b/src/fdw/convert.rs index 4cfe047..f379b25 100644 --- a/src/fdw/convert.rs +++ b/src/fdw/convert.rs @@ -19,206 +19,556 @@ use serde_json::{Map, Number, Value}; const UNIX_TO_POSTGRES_EPOCH_SECS: i64 = 946_684_800; +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConvertErrorKind { + TypeMismatch, + UnsupportedType, + ValueOutOfRange, + Internal, +} + +#[derive(Debug)] +pub struct ConvertError { + pub kind: ConvertErrorKind, + pub message: String, +} + +impl ConvertError { + fn type_mismatch(message: impl Into) -> Self { + Self { + kind: ConvertErrorKind::TypeMismatch, + message: message.into(), + } + } + + fn unsupported_type(message: impl Into) -> Self { + Self { + kind: ConvertErrorKind::UnsupportedType, + message: message.into(), + } + } + + fn value_out_of_range(message: impl Into) -> Self { + Self { + kind: ConvertErrorKind::ValueOutOfRange, + message: message.into(), + } + } + + fn internal(message: impl Into) -> Self { + Self { + kind: ConvertErrorKind::Internal, + message: message.into(), + } + } +} + +impl std::fmt::Display for ConvertError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.message) + } +} + +fn require_target_oid( + target_type_oid: pg_sys::Oid, + expected_oid: pg_sys::Oid, + message: &'static str, +) -> Result<(), ConvertError> { + if target_type_oid != expected_oid { + return Err(ConvertError::type_mismatch(message)); + } + Ok(()) +} + +pub fn validate_arrow_type_for_pg_oid( + arrow_type: &DataType, + target_type_oid: pg_sys::Oid, +) -> Result<(), ConvertError> { + match arrow_type { + DataType::Boolean => require_target_oid( + target_type_oid, + pg_sys::BOOLOID, + "arrow boolean requires postgres boolean", + ), + DataType::Int8 | DataType::UInt8 | DataType::Int16 => require_target_oid( + target_type_oid, + pg_sys::INT2OID, + "arrow int8/uint8/int16 requires postgres int2", + ), + DataType::UInt16 | DataType::Int32 => require_target_oid( + 
target_type_oid, + pg_sys::INT4OID, + "arrow uint16/int32 requires postgres int4", + ), + DataType::UInt32 | DataType::Int64 | DataType::UInt64 => require_target_oid( + target_type_oid, + pg_sys::INT8OID, + "arrow uint32/int64/uint64 requires postgres int8", + ), + DataType::Float16 | DataType::Float32 => require_target_oid( + target_type_oid, + pg_sys::FLOAT4OID, + "arrow float16/float32 requires postgres float4", + ), + DataType::Float64 => require_target_oid( + target_type_oid, + pg_sys::FLOAT8OID, + "arrow float64 requires postgres float8", + ), + DataType::Utf8 | DataType::LargeUtf8 => require_target_oid( + target_type_oid, + pg_sys::TEXTOID, + "arrow utf8 requires postgres text", + ), + DataType::Binary | DataType::LargeBinary | DataType::FixedSizeBinary(_) => { + require_target_oid( + target_type_oid, + pg_sys::BYTEAOID, + "arrow binary requires postgres bytea", + ) + } + DataType::Date32 | DataType::Date64 => require_target_oid( + target_type_oid, + pg_sys::DATEOID, + "arrow date requires postgres date", + ), + DataType::Timestamp(_, tz) => { + if tz.is_some() { + require_target_oid( + target_type_oid, + pg_sys::TIMESTAMPTZOID, + "arrow timestamp with timezone requires postgres timestamptz", + ) + } else { + require_target_oid( + target_type_oid, + pg_sys::TIMESTAMPOID, + "arrow timestamp without timezone requires postgres timestamp", + ) + } + } + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => require_target_oid( + target_type_oid, + pg_sys::NUMERICOID, + "arrow decimal requires postgres numeric", + ), + DataType::Dictionary(_, value) => { + validate_arrow_type_for_pg_oid(value.as_ref(), target_type_oid) + } + DataType::List(elem) | DataType::LargeList(elem) | DataType::FixedSizeList(elem, _) => { + if target_type_oid == pg_sys::JSONBOID { + return Ok(()); + } + let elem_oid = unsafe { pg_sys::get_element_type(target_type_oid) }; + if elem_oid == pg_sys::InvalidOid { + return Err(ConvertError::type_mismatch( + "arrow list requires postgres 
array (or jsonb)", + )); + } + validate_arrow_type_for_pg_oid(elem.data_type(), elem_oid) + } + DataType::Struct(fields) => unsafe { + let tupdesc = + pg_sys::lookup_type_cache(target_type_oid, pg_sys::TYPECACHE_TUPDESC as i32); + if tupdesc.is_null() || (*tupdesc).tupDesc.is_null() { + return Err(ConvertError::type_mismatch( + "arrow struct requires postgres composite type", + )); + } + let tupdesc = (*tupdesc).tupDesc; + let natts = (*tupdesc).natts as usize; + for i in 0..natts { + let attr = *(*tupdesc).attrs.as_ptr().add(i); + if attr.attisdropped { + continue; + } + let name = std::ffi::CStr::from_ptr(attr.attname.data.as_ptr()) + .to_string_lossy() + .to_string(); + let field = fields.iter().find(|f| f.name() == &name).ok_or_else(|| { + ConvertError::type_mismatch(format!("missing struct field: {}", name)) + })?; + validate_arrow_type_for_pg_oid(field.data_type(), attr.atttypid)?; + } + Ok(()) + }, + DataType::Map(_, _) => require_target_oid( + target_type_oid, + pg_sys::JSONBOID, + "arrow map requires postgres jsonb", + ), + _ => require_target_oid( + target_type_oid, + pg_sys::JSONBOID, + "unsupported arrow type requires postgres jsonb", + ), + } +} + pub fn arrow_value_to_datum( array: &dyn Array, row_idx: usize, target_type_oid: pg_sys::Oid, -) -> Result<(pg_sys::Datum, bool), &'static str> { +) -> Result<(pg_sys::Datum, bool), ConvertError> { if array.is_null(row_idx) { return Ok((pg_sys::Datum::from(0usize), true)); } match array.data_type() { DataType::Boolean => { + require_target_oid( + target_type_oid, + pg_sys::BOOLOID, + "arrow boolean requires postgres boolean", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid boolean array")? + .ok_or_else(|| ConvertError::internal("invalid boolean array"))? 
.value(row_idx); - Ok((v.into_datum().ok_or("failed to convert bool")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert bool"))?, + false, + )) } DataType::Int8 => { + require_target_oid( + target_type_oid, + pg_sys::INT2OID, + "arrow int8 requires postgres int2", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid int8 array")? + .ok_or_else(|| ConvertError::internal("invalid int8 array"))? .value(row_idx) as i16; - Ok((v.into_datum().ok_or("failed to convert int8")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert int8"))?, + false, + )) } DataType::Int16 => { + require_target_oid( + target_type_oid, + pg_sys::INT2OID, + "arrow int16 requires postgres int2", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid int16 array")? + .ok_or_else(|| ConvertError::internal("invalid int16 array"))? .value(row_idx); - Ok((v.into_datum().ok_or("failed to convert int16")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert int16"))?, + false, + )) } DataType::Int32 => { + require_target_oid( + target_type_oid, + pg_sys::INT4OID, + "arrow int32 requires postgres int4", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid int32 array")? + .ok_or_else(|| ConvertError::internal("invalid int32 array"))? .value(row_idx); - Ok((v.into_datum().ok_or("failed to convert int32")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert int32"))?, + false, + )) } DataType::Int64 => { + require_target_oid( + target_type_oid, + pg_sys::INT8OID, + "arrow int64 requires postgres int8", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid int64 array")? + .ok_or_else(|| ConvertError::internal("invalid int64 array"))? 
.value(row_idx); - Ok((v.into_datum().ok_or("failed to convert int64")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert int64"))?, + false, + )) } DataType::UInt8 => { + require_target_oid( + target_type_oid, + pg_sys::INT2OID, + "arrow uint8 requires postgres int2", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid uint8 array")? + .ok_or_else(|| ConvertError::internal("invalid uint8 array"))? .value(row_idx) as i16; - Ok((v.into_datum().ok_or("failed to convert uint8")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert uint8"))?, + false, + )) } DataType::UInt16 => { + require_target_oid( + target_type_oid, + pg_sys::INT4OID, + "arrow uint16 requires postgres int4", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid uint16 array")? + .ok_or_else(|| ConvertError::internal("invalid uint16 array"))? .value(row_idx) as i32; - Ok((v.into_datum().ok_or("failed to convert uint16")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert uint16"))?, + false, + )) } DataType::UInt32 => { + require_target_oid( + target_type_oid, + pg_sys::INT8OID, + "arrow uint32 requires postgres int8", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid uint32 array")? + .ok_or_else(|| ConvertError::internal("invalid uint32 array"))? .value(row_idx) as i64; - Ok((v.into_datum().ok_or("failed to convert uint32")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert uint32"))?, + false, + )) } DataType::UInt64 => { + require_target_oid( + target_type_oid, + pg_sys::INT8OID, + "arrow uint64 requires postgres int8", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid uint64 array")? + .ok_or_else(|| ConvertError::internal("invalid uint64 array"))? 
.value(row_idx); if v > i64::MAX as u64 { - return Err("uint64 out of range for int8"); + return Err(ConvertError::value_out_of_range( + "arrow uint64 value out of range for postgres int8", + )); } Ok(( - (v as i64).into_datum().ok_or("failed to convert uint64")?, + (v as i64) + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert uint64"))?, false, )) } DataType::Float16 => { + require_target_oid( + target_type_oid, + pg_sys::FLOAT4OID, + "arrow float16 requires postgres float4", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid float16 array")? + .ok_or_else(|| ConvertError::internal("invalid float16 array"))? .value(row_idx) .to_f32(); - Ok((v.into_datum().ok_or("failed to convert float16")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert float16"))?, + false, + )) } DataType::Float32 => { + require_target_oid( + target_type_oid, + pg_sys::FLOAT4OID, + "arrow float32 requires postgres float4", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid float32 array")? + .ok_or_else(|| ConvertError::internal("invalid float32 array"))? .value(row_idx); - Ok((v.into_datum().ok_or("failed to convert float32")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert float32"))?, + false, + )) } DataType::Float64 => { + require_target_oid( + target_type_oid, + pg_sys::FLOAT8OID, + "arrow float64 requires postgres float8", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid float64 array")? + .ok_or_else(|| ConvertError::internal("invalid float64 array"))? 
.value(row_idx); - Ok((v.into_datum().ok_or("failed to convert float64")?, false)) + Ok(( + v.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert float64"))?, + false, + )) } DataType::Utf8 => { + require_target_oid( + target_type_oid, + pg_sys::TEXTOID, + "arrow utf8 requires postgres text", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid utf8 array")? + .ok_or_else(|| ConvertError::internal("invalid utf8 array"))? .value(row_idx); Ok(( - v.to_string().into_datum().ok_or("failed to convert text")?, + v.to_string() + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert text"))?, false, )) } DataType::LargeUtf8 => { + require_target_oid( + target_type_oid, + pg_sys::TEXTOID, + "arrow large utf8 requires postgres text", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid large utf8 array")? + .ok_or_else(|| ConvertError::internal("invalid large utf8 array"))? .value(row_idx); Ok(( - v.to_string().into_datum().ok_or("failed to convert text")?, + v.to_string() + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert text"))?, false, )) } DataType::Binary => { + require_target_oid( + target_type_oid, + pg_sys::BYTEAOID, + "arrow binary requires postgres bytea", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid binary array")? + .ok_or_else(|| ConvertError::internal("invalid binary array"))? .value(row_idx); Ok(( - v.to_vec().into_datum().ok_or("failed to convert bytea")?, + v.to_vec() + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert bytea"))?, false, )) } DataType::LargeBinary => { + require_target_oid( + target_type_oid, + pg_sys::BYTEAOID, + "arrow large binary requires postgres bytea", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid large binary array")? + .ok_or_else(|| ConvertError::internal("invalid large binary array"))? 
.value(row_idx); Ok(( - v.to_vec().into_datum().ok_or("failed to convert bytea")?, + v.to_vec() + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert bytea"))?, false, )) } DataType::FixedSizeBinary(_) => { + require_target_oid( + target_type_oid, + pg_sys::BYTEAOID, + "arrow fixed size binary requires postgres bytea", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid fixed size binary array")? + .ok_or_else(|| ConvertError::internal("invalid fixed size binary array"))? .value(row_idx); Ok(( - v.to_vec().into_datum().ok_or("failed to convert bytea")?, + v.to_vec() + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert bytea"))?, false, )) } DataType::Date32 => { + require_target_oid( + target_type_oid, + pg_sys::DATEOID, + "arrow date32 requires postgres date", + )?; let days = array .as_any() .downcast_ref::() - .ok_or("invalid date32 array")? + .ok_or_else(|| ConvertError::internal("invalid date32 array"))? .value(row_idx); - let base = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).ok_or("invalid epoch")?; + let base = chrono::NaiveDate::from_ymd_opt(1970, 1, 1) + .ok_or_else(|| ConvertError::internal("invalid epoch"))?; let dt = base .checked_add_signed(chrono::Duration::days(days as i64)) - .ok_or("date32 overflow")?; + .ok_or_else(|| ConvertError::value_out_of_range("arrow date32 overflow"))?; let date = Date::new(dt.year(), dt.month() as u8, dt.day() as u8) - .map_err(|_| "invalid date")?; - Ok((date.into_datum().ok_or("failed to convert date")?, false)) + .map_err(|_| ConvertError::internal("invalid date"))?; + Ok(( + date.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert date"))?, + false, + )) } DataType::Date64 => { + require_target_oid( + target_type_oid, + pg_sys::DATEOID, + "arrow date64 requires postgres date", + )?; let millis = array .as_any() .downcast_ref::() - .ok_or("invalid date64 array")? + .ok_or_else(|| ConvertError::internal("invalid date64 array"))? 
.value(row_idx); - let dt = chrono::DateTime::from_timestamp_millis(millis).ok_or("date64 overflow")?; + let dt = chrono::DateTime::from_timestamp_millis(millis) + .ok_or_else(|| ConvertError::value_out_of_range("arrow date64 overflow"))?; let date = Date::new(dt.year(), dt.month() as u8, dt.day() as u8) - .map_err(|_| "invalid date")?; - Ok((date.into_datum().ok_or("failed to convert date")?, false)) + .map_err(|_| ConvertError::internal("invalid date"))?; + Ok(( + date.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert date"))?, + false, + )) } DataType::Timestamp(unit, tz) => { + validate_arrow_type_for_pg_oid(array.data_type(), target_type_oid)?; let unix_micros = match unit { ArrowTimeUnit::Second => { let secs = array .as_any() .downcast_ref::() - .ok_or("invalid timestamp(s) array")? + .ok_or_else(|| ConvertError::internal("invalid timestamp(s) array"))? .value(row_idx); secs.saturating_mul(1_000_000) } @@ -226,20 +576,20 @@ pub fn arrow_value_to_datum( let millis = array .as_any() .downcast_ref::() - .ok_or("invalid timestamp(ms) array")? + .ok_or_else(|| ConvertError::internal("invalid timestamp(ms) array"))? .value(row_idx); millis.saturating_mul(1_000) } ArrowTimeUnit::Microsecond => array .as_any() .downcast_ref::() - .ok_or("invalid timestamp(us) array")? + .ok_or_else(|| ConvertError::internal("invalid timestamp(us) array"))? .value(row_idx), ArrowTimeUnit::Nanosecond => { let nanos = array .as_any() .downcast_ref::() - .ok_or("invalid timestamp(ns) array")? + .ok_or_else(|| ConvertError::internal("invalid timestamp(ns) array"))? 
.value(row_idx); nanos / 1_000 } @@ -247,60 +597,89 @@ pub fn arrow_value_to_datum( let pg_micros = unix_micros - UNIX_TO_POSTGRES_EPOCH_SECS.saturating_mul(1_000_000); - if tz.is_some() || target_type_oid == pg_sys::TIMESTAMPTZOID { + if tz.is_some() { let ts = TimestampWithTimeZone::try_from(pg_micros) - .map_err(|_| "invalid timestamptz")?; + .map_err(|_| ConvertError::value_out_of_range("invalid timestamptz"))?; Ok(( - ts.into_datum().ok_or("failed to convert timestamptz")?, + ts.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert timestamptz"))?, false, )) } else { - let ts = Timestamp::try_from(pg_micros).map_err(|_| "invalid timestamp")?; - Ok((ts.into_datum().ok_or("failed to convert timestamp")?, false)) + let ts = Timestamp::try_from(pg_micros) + .map_err(|_| ConvertError::value_out_of_range("invalid timestamp"))?; + Ok(( + ts.into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert timestamp"))?, + false, + )) } } DataType::Decimal128(_, _) => { + require_target_oid( + target_type_oid, + pg_sys::NUMERICOID, + "arrow decimal128 requires postgres numeric", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid decimal128 array")? + .ok_or_else(|| ConvertError::internal("invalid decimal128 array"))? .value_as_string(row_idx); - let numeric = AnyNumeric::try_from(v.as_str()).map_err(|_| "invalid numeric")?; + let numeric = AnyNumeric::try_from(v.as_str()) + .map_err(|_| ConvertError::internal("invalid numeric"))?; Ok(( - numeric.into_datum().ok_or("failed to convert numeric")?, + numeric + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert numeric"))?, false, )) } DataType::Decimal256(_, _) => { + require_target_oid( + target_type_oid, + pg_sys::NUMERICOID, + "arrow decimal256 requires postgres numeric", + )?; let v = array .as_any() .downcast_ref::() - .ok_or("invalid decimal256 array")? + .ok_or_else(|| ConvertError::internal("invalid decimal256 array"))? 
.value_as_string(row_idx); - let numeric = AnyNumeric::try_from(v.as_str()).map_err(|_| "invalid numeric")?; + let numeric = AnyNumeric::try_from(v.as_str()) + .map_err(|_| ConvertError::internal("invalid numeric"))?; Ok(( - numeric.into_datum().ok_or("failed to convert numeric")?, + numeric + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert numeric"))?, false, )) } DataType::Dictionary(_, _) => { let dict = array .as_any_dictionary_opt() - .ok_or("invalid dictionary array")?; + .ok_or_else(|| ConvertError::internal("invalid dictionary array"))?; let value_idx = dictionary_key_to_usize(dict.keys(), row_idx)?; let values = dict.values().as_ref(); if value_idx >= values.len() { - return Err("dictionary key out of range"); + return Err(ConvertError::internal("dictionary key out of range")); } arrow_value_to_datum(values, value_idx, target_type_oid) } DataType::List(_) | DataType::LargeList(_) => { let elem_oid = unsafe { pg_sys::get_element_type(target_type_oid) }; if elem_oid == pg_sys::InvalidOid { - let json = arrow_value_to_json(array, row_idx)?; - return Ok(( - JsonB(json).into_datum().ok_or("failed to convert jsonb")?, - false, + if target_type_oid == pg_sys::JSONBOID { + let json = arrow_value_to_json(array, row_idx)?; + return Ok(( + JsonB(json) + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert jsonb"))?, + false, + )); + } + return Err(ConvertError::type_mismatch( + "arrow list requires postgres array (or jsonb)", )); } list_to_array_datum(array, row_idx, target_type_oid, elem_oid) @@ -308,10 +687,17 @@ pub fn arrow_value_to_datum( DataType::FixedSizeList(_, _) => { let elem_oid = unsafe { pg_sys::get_element_type(target_type_oid) }; if elem_oid == pg_sys::InvalidOid { - let json = arrow_value_to_json(array, row_idx)?; - return Ok(( - JsonB(json).into_datum().ok_or("failed to convert jsonb")?, - false, + if target_type_oid == pg_sys::JSONBOID { + let json = arrow_value_to_json(array, row_idx)?; + return Ok(( + 
JsonB(json) + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert jsonb"))?, + false, + )); + } + return Err(ConvertError::type_mismatch( + "arrow fixed size list requires postgres array (or jsonb)", )); } fixed_size_list_to_array_datum(array, row_idx, target_type_oid, elem_oid) @@ -320,20 +706,34 @@ pub fn arrow_value_to_datum( let struct_array = array .as_any() .downcast_ref::() - .ok_or("invalid struct array")?; + .ok_or_else(|| ConvertError::internal("invalid struct array"))?; struct_to_composite_datum(struct_array, row_idx, target_type_oid) } DataType::Map(_, _) => { + require_target_oid( + target_type_oid, + pg_sys::JSONBOID, + "arrow map requires postgres jsonb", + )?; let json = arrow_value_to_json(array, row_idx)?; Ok(( - JsonB(json).into_datum().ok_or("failed to convert jsonb")?, + JsonB(json) + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert jsonb"))?, false, )) } _ => { + require_target_oid( + target_type_oid, + pg_sys::JSONBOID, + "unsupported arrow type requires postgres jsonb", + )?; let json = arrow_value_to_json(array, row_idx)?; Ok(( - JsonB(json).into_datum().ok_or("failed to convert jsonb")?, + JsonB(json) + .into_datum() + .ok_or_else(|| ConvertError::internal("failed to convert jsonb"))?, false, )) } @@ -345,17 +745,17 @@ fn list_to_array_datum( row_idx: usize, array_type_oid: pg_sys::Oid, elem_oid: pg_sys::Oid, -) -> Result<(pg_sys::Datum, bool), &'static str> { +) -> Result<(pg_sys::Datum, bool), ConvertError> { fn handle_list( array: &dyn Array, row_idx: usize, array_type_oid: pg_sys::Oid, elem_oid: pg_sys::Oid, - ) -> Result<(pg_sys::Datum, bool), &'static str> { + ) -> Result<(pg_sys::Datum, bool), ConvertError> { let list_array = array .as_any() .downcast_ref::>() - .ok_or("invalid list array")?; + .ok_or_else(|| ConvertError::internal("invalid list array"))?; let values = list_array.value(row_idx); array_values_to_pg_array(values.as_ref(), array_type_oid, elem_oid) } @@ -363,7 +763,7 @@ fn 
list_to_array_datum( match array.data_type() { DataType::List(_) => handle_list::(array, row_idx, array_type_oid, elem_oid), DataType::LargeList(_) => handle_list::(array, row_idx, array_type_oid, elem_oid), - _ => Err("not a list array"), + _ => Err(ConvertError::internal("not a list array")), } } @@ -372,11 +772,11 @@ fn fixed_size_list_to_array_datum( row_idx: usize, array_type_oid: pg_sys::Oid, elem_oid: pg_sys::Oid, -) -> Result<(pg_sys::Datum, bool), &'static str> { +) -> Result<(pg_sys::Datum, bool), ConvertError> { let list_array = array .as_any() .downcast_ref::() - .ok_or("invalid fixed size list array")?; + .ok_or_else(|| ConvertError::internal("invalid fixed size list array"))?; let values = list_array.value(row_idx); array_values_to_pg_array(values.as_ref(), array_type_oid, elem_oid) } @@ -385,7 +785,7 @@ fn array_values_to_pg_array( values: &dyn Array, array_type_oid: pg_sys::Oid, elem_oid: pg_sys::Oid, -) -> Result<(pg_sys::Datum, bool), &'static str> { +) -> Result<(pg_sys::Datum, bool), ConvertError> { let _ = array_type_oid; let len = values.len(); let mut datums = Vec::with_capacity(len); @@ -417,7 +817,7 @@ fn array_values_to_pg_array( typalign, ); if arr.is_null() { - return Err("failed to construct array"); + return Err(ConvertError::internal("failed to construct array")); } let datum = pg_sys::Datum::from(arr); @@ -429,12 +829,14 @@ fn struct_to_composite_datum( struct_array: &StructArray, row_idx: usize, composite_type_oid: pg_sys::Oid, -) -> Result<(pg_sys::Datum, bool), &'static str> { +) -> Result<(pg_sys::Datum, bool), ConvertError> { unsafe { let tupdesc = pg_sys::lookup_type_cache(composite_type_oid, pg_sys::TYPECACHE_TUPDESC as i32); if tupdesc.is_null() || (*tupdesc).tupDesc.is_null() { - return Err("failed to lookup composite tupdesc"); + return Err(ConvertError::type_mismatch( + "failed to lookup composite tupdesc", + )); } let tupdesc = (*tupdesc).tupDesc; @@ -458,7 +860,9 @@ fn struct_to_composite_datum( .fields() .iter() 
.position(|f| f.name() == &name) - .ok_or("struct field name mismatch")?; + .ok_or_else(|| { + ConvertError::type_mismatch(format!("struct field name mismatch: {}", name)) + })?; let col = struct_array.column(idx); let (d, isnull) = arrow_value_to_datum(col.as_ref(), row_idx, attr.atttypid)?; values[i] = d; @@ -467,96 +871,87 @@ fn struct_to_composite_datum( let htup = pg_sys::heap_form_tuple(tupdesc, values.as_mut_ptr(), nulls.as_mut_ptr()); if htup.is_null() { - return Err("failed to form heap tuple"); + return Err(ConvertError::internal("failed to form heap tuple")); } let datum = pg_sys::HeapTupleGetDatum(htup); Ok((datum, false)) } } -fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result { +fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result { if array.is_null(row_idx) { return Ok(Value::Null); } match array.data_type() { - DataType::Boolean => { - let v = array + DataType::Boolean => Ok(Value::Bool( + array .as_any() .downcast_ref::() - .ok_or("invalid boolean array")? - .value(row_idx); - Ok(Value::Bool(v)) - } - DataType::Int8 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid boolean array"))? + .value(row_idx), + )), + DataType::Int8 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid int8 array")? - .value(row_idx) as i64; - Ok(json_number(v)) - } - DataType::Int16 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid int8 array"))? + .value(row_idx) as i64, + )), + DataType::Int16 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid int16 array")? - .value(row_idx) as i64; - Ok(json_number(v)) - } - DataType::Int32 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid int16 array"))? + .value(row_idx) as i64, + )), + DataType::Int32 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid int32 array")? 
- .value(row_idx) as i64; - Ok(json_number(v)) - } - DataType::Int64 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid int32 array"))? + .value(row_idx) as i64, + )), + DataType::Int64 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid int64 array")? - .value(row_idx); - Ok(json_number(v)) - } - DataType::UInt8 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid int64 array"))? + .value(row_idx), + )), + DataType::UInt8 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid uint8 array")? - .value(row_idx) as i64; - Ok(json_number(v)) - } - DataType::UInt16 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid uint8 array"))? + .value(row_idx) as i64, + )), + DataType::UInt16 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid uint16 array")? - .value(row_idx) as i64; - Ok(json_number(v)) - } - DataType::UInt32 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid uint16 array"))? + .value(row_idx) as i64, + )), + DataType::UInt32 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid uint32 array")? - .value(row_idx) as i64; - Ok(json_number(v)) - } - DataType::UInt64 => { - let v = array + .ok_or_else(|| ConvertError::internal("invalid uint32 array"))? + .value(row_idx) as i64, + )), + DataType::UInt64 => Ok(json_number( + array .as_any() .downcast_ref::() - .ok_or("invalid uint64 array")? - .value(row_idx) as i64; - Ok(json_number(v)) - } + .ok_or_else(|| ConvertError::internal("invalid uint64 array"))? + .value(row_idx) as i64, + )), DataType::Float16 => { let v = array .as_any() .downcast_ref::() - .ok_or("invalid float16 array")? + .ok_or_else(|| ConvertError::internal("invalid float16 array"))? .value(row_idx) .to_f32() as f64; Ok(Number::from_f64(v) @@ -567,7 +962,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid float32 array")? 
+ .ok_or_else(|| ConvertError::internal("invalid float32 array"))? .value(row_idx) as f64; Ok(Number::from_f64(v) .map(Value::Number) @@ -577,7 +972,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid float64 array")? + .ok_or_else(|| ConvertError::internal("invalid float64 array"))? .value(row_idx); Ok(Number::from_f64(v) .map(Value::Number) @@ -587,7 +982,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid utf8 array")? + .ok_or_else(|| ConvertError::internal("invalid utf8 array"))? .value(row_idx); Ok(Value::String(v.to_string())) } @@ -595,7 +990,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid large utf8 array")? + .ok_or_else(|| ConvertError::internal("invalid large utf8 array"))? .value(row_idx); Ok(Value::String(v.to_string())) } @@ -603,7 +998,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid binary array")? + .ok_or_else(|| ConvertError::internal("invalid binary array"))? .value(row_idx); Ok(Value::String(STANDARD.encode(v))) } @@ -611,7 +1006,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid large binary array")? + .ok_or_else(|| ConvertError::internal("invalid large binary array"))? .value(row_idx); Ok(Value::String(STANDARD.encode(v))) } @@ -619,7 +1014,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid fixed size binary array")? + .ok_or_else(|| ConvertError::internal("invalid fixed size binary array"))? 
.value(row_idx); Ok(Value::String(STANDARD.encode(v))) } @@ -630,24 +1025,24 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result>() - .ok_or("invalid list array")?; + .ok_or_else(|| ConvertError::internal("invalid list array"))?; l.value(row_idx) } DataType::LargeList(_) => { let l = array .as_any() .downcast_ref::>() - .ok_or("invalid large list array")?; + .ok_or_else(|| ConvertError::internal("invalid large list array"))?; l.value(row_idx) } DataType::FixedSizeList(_, _) => { let l = array .as_any() .downcast_ref::() - .ok_or("invalid fixed size list array")?; + .ok_or_else(|| ConvertError::internal("invalid fixed size list array"))?; l.value(row_idx) } - _ => return Err("invalid list data type"), + _ => return Err(ConvertError::internal("invalid list type")), }; for i in 0..values.len() { out.push(arrow_value_to_json(values.as_ref(), i)?); @@ -658,7 +1053,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid struct array")?; + .ok_or_else(|| ConvertError::internal("invalid struct array"))?; let mut map = Map::new(); for (i, f) in fields.iter().enumerate() { let col = struct_array.column(i); @@ -673,7 +1068,7 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid decimal128 array")? + .ok_or_else(|| ConvertError::internal("invalid decimal128 array"))? .value_as_string(row_idx); Ok(Value::String(v)) } @@ -681,18 +1076,18 @@ fn arrow_value_to_json(array: &dyn Array, row_idx: usize) -> Result() - .ok_or("invalid decimal256 array")? + .ok_or_else(|| ConvertError::internal("invalid decimal256 array"))? 
.value_as_string(row_idx); Ok(Value::String(v)) } DataType::Dictionary(_, _) => { let dict = array .as_any_dictionary_opt() - .ok_or("invalid dictionary array")?; + .ok_or_else(|| ConvertError::internal("invalid dictionary array"))?; let value_idx = dictionary_key_to_usize(dict.keys(), row_idx)?; let values = dict.values().as_ref(); if value_idx >= values.len() { - return Err("dictionary key out of range"); + return Err(ConvertError::internal("dictionary key out of range")); } arrow_value_to_json(values, value_idx) } @@ -707,60 +1102,62 @@ fn json_number(v: i64) -> Value { Value::Number(Number::from(v)) } -fn dictionary_key_to_usize(keys: &dyn Array, row_idx: usize) -> Result { +fn dictionary_key_to_usize(keys: &dyn Array, row_idx: usize) -> Result { match keys.data_type() { DataType::Int8 => { let v = keys .as_any() .downcast_ref::() - .ok_or("invalid dictionary keys (int8)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (int8)"))? .value(row_idx) as i64; - usize::try_from(v).map_err(|_| "negative dictionary key") + usize::try_from(v).map_err(|_| ConvertError::internal("negative dictionary key")) } DataType::Int16 => { let v = keys .as_any() .downcast_ref::() - .ok_or("invalid dictionary keys (int16)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (int16)"))? .value(row_idx) as i64; - usize::try_from(v).map_err(|_| "negative dictionary key") + usize::try_from(v).map_err(|_| ConvertError::internal("negative dictionary key")) } DataType::Int32 => { let v = keys .as_any() .downcast_ref::() - .ok_or("invalid dictionary keys (int32)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (int32)"))? .value(row_idx) as i64; - usize::try_from(v).map_err(|_| "negative dictionary key") + usize::try_from(v).map_err(|_| ConvertError::internal("negative dictionary key")) } DataType::Int64 => { let v = keys .as_any() .downcast_ref::() - .ok_or("invalid dictionary keys (int64)")? 
+ .ok_or_else(|| ConvertError::internal("invalid dictionary keys (int64)"))? .value(row_idx); - usize::try_from(v).map_err(|_| "negative dictionary key") + usize::try_from(v).map_err(|_| ConvertError::internal("negative dictionary key")) } DataType::UInt8 => Ok(keys .as_any() .downcast_ref::<UInt8Array>() - .ok_or("invalid dictionary keys (uint8)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (uint8)"))? .value(row_idx) as usize), DataType::UInt16 => Ok(keys .as_any() .downcast_ref::<UInt16Array>() - .ok_or("invalid dictionary keys (uint16)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (uint16)"))? .value(row_idx) as usize), DataType::UInt32 => Ok(keys .as_any() .downcast_ref::<UInt32Array>() - .ok_or("invalid dictionary keys (uint32)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (uint32)"))? .value(row_idx) as usize), DataType::UInt64 => Ok(keys .as_any() .downcast_ref::<UInt64Array>() - .ok_or("invalid dictionary keys (uint64)")? + .ok_or_else(|| ConvertError::internal("invalid dictionary keys (uint64)"))?
.value(row_idx) as usize), - _ => Err("unsupported dictionary key type"), + _ => Err(ConvertError::unsupported_type( + "unsupported dictionary key type", + )), } } diff --git a/src/fdw/scan.rs b/src/fdw/scan.rs index 26a796c..b83890f 100644 --- a/src/fdw/scan.rs +++ b/src/fdw/scan.rs @@ -1,9 +1,10 @@ -use crate::fdw::convert::arrow_value_to_datum; +use crate::fdw::convert::{arrow_value_to_datum, validate_arrow_type_for_pg_oid, ConvertErrorKind}; use crate::fdw::options::LanceFdwOptions; use futures::StreamExt; use lance_rs::dataset::scanner::DatasetRecordBatchStream; use lance_rs::Dataset; use pgrx::pg_sys; +use pgrx::{ereport, PgSqlErrorCode}; use std::ffi::CString; use std::pin::Pin; use std::sync::Arc; @@ -17,9 +18,22 @@ pub struct LanceScanState { current_batch: Option<RecordBatch>, current_row: usize, atttypids: Vec<pg_sys::Oid>, + attnames: Vec<Option<String>>, att_to_batch_col: Vec<Option<usize>>, } +fn pg_type_name(oid: pg_sys::Oid) -> String { + unsafe { + let ptr = pg_sys::format_type_be(oid); + if ptr.is_null() { + return format!("oid {}", oid); + } + let s = std::ffi::CStr::from_ptr(ptr).to_string_lossy().to_string(); + pg_sys::pfree(ptr.cast()); + s + } +} + #[pgrx::pg_guard] pub unsafe extern "C-unwind" fn get_foreign_rel_size( _root: *mut pg_sys::PlannerInfo, @@ -108,20 +122,35 @@ pub unsafe extern "C-unwind" fn begin_foreign_scan( let relid = (*relation).rd_id; let opts = LanceFdwOptions::from_foreign_table(relid).unwrap_or_else(|e| { - pgrx::error!("invalid foreign table options: {}", e); + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_INVALID_OPTION_NAME, + "invalid foreign table options", + format!("relation_oid={} error={}", relid, e), + ); }); let runtime = Arc::new(Runtime::new().unwrap_or_else(|e| { - pgrx::error!("failed to create tokio runtime: {}", e); + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_UNABLE_TO_CREATE_EXECUTION, + "failed to create tokio runtime", + format!("error={}", e), + ); })); let dataset = runtime .block_on(async { Dataset::open(&opts.uri).await })
.unwrap_or_else(|e| { - pgrx::error!("failed to open dataset {}: {}", opts.uri, e); + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_TABLE_NOT_FOUND, + "failed to open lance dataset", + format!("uri={} error={}", opts.uri, e), + ); }); - let stream = create_stream(&runtime, &dataset, opts.batch_size); + let stream = create_stream(&runtime, &dataset, &opts.uri, opts.batch_size); let tupdesc = (*relation).rd_att; if tupdesc.is_null() { @@ -143,22 +172,62 @@ } } - let dataset_field_names: Vec<String> = dataset - .schema() - .fields - .iter() - .map(|f| f.name.clone()) - .collect(); + let dataset_fields = &dataset.schema().fields; + let dataset_field_names: Vec<String> = dataset_fields.iter().map(|f| f.name.clone()).collect(); + let mut name_to_idx = std::collections::BTreeMap::<String, usize>::new(); + for (idx, f) in dataset_fields.iter().enumerate() { + name_to_idx.insert(f.name.clone(), idx); + } let mut att_to_batch_col = Vec::with_capacity(natts); - for name in attnames { + for (att_idx, name) in attnames.iter().enumerate() { if let Some(name) = name { - let idx = dataset_field_names - .iter() - .position(|n| n == &name) - .unwrap_or_else(|| { - pgrx::error!("column not found in dataset schema: {}", name); - }); + let idx = name_to_idx.get(name).copied().unwrap_or_else(|| { + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_COLUMN_NAME_NOT_FOUND, + "column not found in lance dataset schema", + format!( + "uri={} column={} dataset_columns={}", + opts.uri, + name, + dataset_field_names.join(",") + ), + ); + }); + + let field = &dataset_fields[idx]; + let arrow_type = field.data_type(); + if let Err(e) = validate_arrow_type_for_pg_oid(&arrow_type, atttypids[att_idx]) { + let (errcode, message) = match e.kind { + ConvertErrorKind::TypeMismatch => ( + PgSqlErrorCode::ERRCODE_FDW_INVALID_DATA_TYPE_DESCRIPTORS, + "column type mismatch between foreign table and dataset schema", + ), + ConvertErrorKind::UnsupportedType | 
ConvertErrorKind::ValueOutOfRange => ( + PgSqlErrorCode::ERRCODE_FDW_INVALID_DATA_TYPE, + "unsupported column type for lance_fdw", + ), + ConvertErrorKind::Internal => ( + PgSqlErrorCode::ERRCODE_FDW_ERROR, + "internal lance_fdw schema validation error", + ), + }; + ereport!( + ERROR, + errcode, + message, + format!( + "uri={} column={} arrow_type={} pg_type={} error={}", + opts.uri, + name, + arrow_type, + pg_type_name(atttypids[att_idx]), + e + ), + ); + } + att_to_batch_col.push(Some(idx)); } else { att_to_batch_col.push(None); @@ -173,6 +242,7 @@ pub unsafe extern "C-unwind" fn begin_foreign_scan( current_batch: None, current_row: 0, atttypids, + attnames: attnames.clone(), att_to_batch_col, }); @@ -210,7 +280,14 @@ pub unsafe extern "C-unwind" fn iterate_foreign_scan( let next = state.runtime.block_on(async { state.stream.next().await }); match next { None => return slot, - Some(Err(e)) => pgrx::error!("failed to read next batch: {}", e), + Some(Err(e)) => { + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_ERROR, + "failed to read next record batch", + format!("uri={} error={}", state.opts.uri, e), + ); + } Some(Ok(batch)) => { state.current_batch = Some(batch); state.current_row = 0; @@ -250,12 +327,48 @@ pub unsafe extern "C-unwind" fn iterate_foreign_scan( .copied() .flatten() .unwrap_or_else(|| { - pgrx::error!("missing batch column mapping for attribute {}", i + 1); + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_INCONSISTENT_DESCRIPTOR_INFORMATION, + "missing batch column mapping for attribute", + format!("attribute_number={}", i + 1), + ); }); let col = batch.column(batch_idx); let (datum, isnull) = arrow_value_to_datum(col.as_ref(), row, state.atttypids[i]) .unwrap_or_else(|e| { - pgrx::error!("failed to convert column {}: {}", i + 1, e); + let col_name = state + .attnames + .get(i) + .and_then(|v| v.as_deref()) + .unwrap_or(""); + let (errcode, message) = match e.kind { + ConvertErrorKind::TypeMismatch => ( + 
PgSqlErrorCode::ERRCODE_FDW_INVALID_DATA_TYPE_DESCRIPTORS, + "column type mismatch between foreign table and dataset schema", ), + ConvertErrorKind::UnsupportedType | ConvertErrorKind::ValueOutOfRange => ( + PgSqlErrorCode::ERRCODE_FDW_INVALID_DATA_TYPE, + "unsupported column type for lance_fdw", ), + ConvertErrorKind::Internal => ( + PgSqlErrorCode::ERRCODE_FDW_ERROR, + "internal lance_fdw conversion error", ), + }; + ereport!( + ERROR, + errcode, + message, + format!( + "uri={} column={} arrow_type={} pg_type={} error={}", + state.opts.uri, + col_name, + col.data_type(), + pg_type_name(state.atttypids[i]), + e + ), + ); }); *(*slot).tts_values.add(i) = datum; *(*slot).tts_isnull.add(i) = isnull; @@ -277,7 +390,12 @@ pub unsafe extern "C-unwind" fn rescan_foreign_scan(node: *mut pg_sys::ForeignSc return; } let state = &mut *state_ptr; - let stream = create_stream(&state.runtime, &state.dataset, state.opts.batch_size); + let stream = create_stream( + &state.runtime, + &state.dataset, + &state.opts.uri, + state.opts.batch_size, + ); state.stream = Box::pin(stream); state.current_batch = None; state.current_row = 0; @@ -388,11 +506,19 @@ fn format_projection_list(relation: *mut pg_sys::RelationData) -> String { fn create_stream( runtime: &Arc<Runtime>, dataset: &Dataset, + uri: &str, batch_size: usize, ) -> DatasetRecordBatchStream { let mut scanner = dataset.scan(); scanner.batch_size(batch_size); runtime .block_on(async { scanner.try_into_stream().await }) - .unwrap_or_else(|e| pgrx::error!("failed to create scanner stream: {}", e)) + .unwrap_or_else(|e| { + ereport!( + ERROR, + PgSqlErrorCode::ERRCODE_FDW_UNABLE_TO_CREATE_EXECUTION, + "failed to create scanner stream", + format!("uri={} error={}", uri, e), + ); + }) } diff --git a/src/tests.rs b/src/tests.rs index 9cfba52..17d2cbd 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -5,7 +5,7 @@ mod tests { use super::*; use arrow::array::{ builder::StringDictionaryBuilder, Array, BooleanArray, Decimal128Array, 
Float32Array, - Int32Array, ListBuilder, StringArray, StructArray, UInt16Array, UInt32Array, + Int32Array, ListBuilder, StringArray, StructArray, UInt16Array, UInt32Array, UInt64Array, }; use arrow::datatypes::{DataType, Field, Int32Type, Schema}; use arrow::record_batch::RecordBatch; @@ -124,6 +129,11 @@ mod tests { out } + fn quote_literal(value: &str) -> String { + let escaped = value.replace('\'', "''"); + format!("'{}'", escaped) + } + fn list_slt_files(dir: &Path) -> Vec<PathBuf> { let mut files: Vec<PathBuf> = fs::read_dir(dir) .expect("read_dir tests/sql") @@ -255,6 +260,33 @@ Ok(table_path) } + + fn create_table_with_u64_overflow( + &self, + ) -> Result<PathBuf, Box<dyn std::error::Error>> { + let table_path = self.temp_dir.path().join("fdw_u64_overflow"); + + let id_array = Int32Array::from(vec![1, 2, 3]); + let u64_array = UInt64Array::from(vec![u64::MAX, u64::MAX, u64::MAX]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("u64", DataType::UInt64, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(id_array), Arc::new(u64_array)], + )?; + + let reader = arrow::record_batch::RecordBatchIterator::new(vec![Ok(batch)], schema); + let rt = tokio::runtime::Runtime::new().unwrap(); + rt.block_on(async { + Dataset::write(reader, table_path.to_str().unwrap(), None).await + })?; + + Ok(table_path) + } } #[pg_test] @@ -353,4 +385,124 @@ CREATE SERVER {server} FOREIGN DATA WRAPPER lance_fdw;\n\n" Spi::run("SELECT pg_advisory_unlock(424242)").expect("advisory unlock"); } + + #[pg_test] + fn test_fdw_diagnostic_errors() { + Spi::run("SELECT pg_advisory_lock(424242)").expect("advisory lock"); + + let gen = LanceTestDataGenerator::new().expect("generator"); + let struct_list_path = gen + .create_table_with_struct_and_list() + .expect("create table"); + let struct_list_uri = struct_list_path.to_str().unwrap(); + + let overflow_path = gen.create_table_with_u64_overflow().expect("create table"); + let overflow_uri = 
overflow_path.to_str().unwrap(); + + Spi::run("DROP SCHEMA IF EXISTS slt_unit CASCADE").expect("drop schema"); + Spi::run("CREATE SCHEMA slt_unit").expect("create schema"); + Spi::run("SET search_path TO slt_unit, public").expect("set search_path"); + Spi::run("DROP SERVER IF EXISTS lance_srv_unit CASCADE").expect("drop server"); + Spi::run("CREATE SERVER lance_srv_unit FOREIGN DATA WRAPPER lance_fdw") + .expect("create server"); + + Spi::run( + "CREATE OR REPLACE FUNCTION slt_unit.capture_error(sql text) \ + RETURNS jsonb LANGUAGE plpgsql AS $$ \ + DECLARE st text; msg text; det text; \ + BEGIN \ + EXECUTE sql; \ + RETURN NULL; \ + EXCEPTION WHEN OTHERS THEN \ + GET STACKED DIAGNOSTICS st = RETURNED_SQLSTATE, msg = MESSAGE_TEXT, det = PG_EXCEPTION_DETAIL; \ + RETURN jsonb_build_object('sqlstate', st, 'message', msg, 'detail', det); \ + END $$;", + ) + .expect("create capture_error"); + + let create_missing = format!( + "CREATE FOREIGN TABLE slt_unit.t_missing(id int4, missing_col text) \ + SERVER lance_srv_unit OPTIONS (uri {});", + quote_literal(struct_list_uri) + ); + Spi::run(&create_missing).expect("create t_missing"); + let st = Spi::get_one::<String>(&format!( + "SELECT capture_error({})->>'sqlstate'", + quote_literal("SELECT count(*) FROM slt_unit.t_missing") + )) + .expect("capture missing") + .expect("state"); + assert_eq!(st, "HV005"); + + let create_mismatch = format!( + "CREATE FOREIGN TABLE slt_unit.t_mismatch(id text) \ + SERVER lance_srv_unit OPTIONS (uri {});", + quote_literal(struct_list_uri) + ); + Spi::run(&create_mismatch).expect("create t_mismatch"); + let st = Spi::get_one::<String>(&format!( + "SELECT capture_error({})->>'sqlstate'", + quote_literal("SELECT count(*) FROM slt_unit.t_mismatch") + )) + .expect("capture mismatch") + .expect("state"); + assert_eq!(st, "HV006"); + + Spi::run("CREATE TYPE slt_unit.meta_bad AS (score float4, missing text);") + .expect("create meta_bad"); + let create_struct_mismatch = format!( + "CREATE FOREIGN TABLE 
slt_unit.t_struct_mismatch(id int4, meta slt_unit.meta_bad) \ + SERVER lance_srv_unit OPTIONS (uri {});", + quote_literal(struct_list_uri) + ); + Spi::run(&create_struct_mismatch).expect("create t_struct_mismatch"); + let st = Spi::get_one::<String>(&format!( + "SELECT capture_error({})->>'sqlstate'", + quote_literal("SELECT count(*) FROM slt_unit.t_struct_mismatch") + )) + .expect("capture struct mismatch") + .expect("state"); + assert_eq!(st, "HV006"); + + let bad_uri = gen.temp_dir.path().join("does_not_exist"); + let create_bad_uri = format!( + "CREATE FOREIGN TABLE slt_unit.t_bad_uri(id int4) \ + SERVER lance_srv_unit OPTIONS (uri {});", + quote_literal(bad_uri.to_str().unwrap()) + ); + Spi::run(&create_bad_uri).expect("create t_bad_uri"); + let st = Spi::get_one::<String>(&format!( + "SELECT capture_error({})->>'sqlstate'", + quote_literal("SELECT count(*) FROM slt_unit.t_bad_uri") + )) + .expect("capture bad uri") + .expect("state"); + assert_eq!(st, "HV00R"); + + let create_overflow = format!( + "CREATE FOREIGN TABLE slt_unit.t_u64_overflow(id int4, u64 int8) \ + SERVER lance_srv_unit OPTIONS (uri {});", + quote_literal(overflow_uri) + ); + Spi::run(&create_overflow).expect("create t_u64_overflow"); + let st = Spi::get_one::<String>(&format!( + "SELECT capture_error({})->>'sqlstate'", + quote_literal("SELECT count(*) FROM slt_unit.t_u64_overflow") + )) + .expect("capture overflow") + .expect("state"); + assert_eq!(st, "HV004"); + + let detail = Spi::get_one::<String>(&format!( + "SELECT capture_error({})->>'detail'", + quote_literal("SELECT count(*) FROM slt_unit.t_u64_overflow") + )) + .expect("capture overflow detail") + .expect("detail"); + assert!(detail.contains("arrow_type=")); + assert!(detail.contains("pg_type=")); + assert!(detail.contains("uri=")); + + Spi::run("SELECT pg_advisory_unlock(424242)").expect("advisory unlock"); + } }