Skip to content

Commit b3c5792

Browse files
committed
fix: infer CSV/TSV columns during stage registration
1 parent 62966a3 commit b3c5792

File tree

3 files changed

+154
-82
lines changed

3 files changed

+154
-82
lines changed

src/bendpy/src/context.rs

Lines changed: 105 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
use std::sync::Arc;
1616

1717
use databend_common_exception::Result;
18+
use databend_common_expression::BlockEntry;
19+
use databend_common_expression::Column;
1820
use databend_common_meta_app::principal::BUILTIN_ROLE_ACCOUNT_ADMIN;
1921
use databend_common_version::BUILD_INFO;
2022
use databend_query::sessions::BuildInfoRef;
@@ -32,33 +34,66 @@ fn resolve_file_path(path: &str) -> String {
3234
if path.contains("://") {
3335
return path.to_owned();
3436
}
37+
3538
if path.starts_with('/') {
3639
return format!("fs://{}", path);
3740
}
41+
3842
format!(
3943
"fs://{}/{}",
4044
std::env::current_dir().unwrap().to_str().unwrap(),
4145
path
4246
)
4347
}
4448

45-
/// Extract the real filesystem path from a `fs://` URI.
46-
fn fs_path_from_uri(uri: &str) -> Option<&str> {
47-
uri.strip_prefix("fs://")
49+
fn extract_string_column(
50+
entry: &BlockEntry,
51+
) -> Option<&databend_common_expression::types::StringColumn> {
52+
match entry {
53+
BlockEntry::Column(Column::String(col)) => Some(col),
54+
BlockEntry::Column(Column::Nullable(n)) => match &n.column {
55+
Column::String(col) => Some(col),
56+
_ => None,
57+
},
58+
_ => None,
59+
}
60+
}
61+
62+
/// Assemble the `infer_schema` table-function query used to discover the
/// column names of a staged CSV/TSV location.
///
/// `pattern` and `connection` are optional; when present they are appended
/// as additional named arguments. The file format is upper-cased to match
/// the spelling `infer_schema` expects.
fn build_infer_schema_sql(
    file_path: &str,
    file_format: &str,
    pattern: Option<&str>,
    connection: Option<&str>,
) -> String {
    // Build the argument list incrementally: mandatory args first, then
    // pattern, then connection (order matters for the generated SQL text).
    let mut args = format!(
        "location => '{}', file_format => '{}'",
        file_path,
        file_format.to_uppercase()
    );
    if let Some(p) = pattern {
        args.push_str(&format!(", pattern => '{}'", p));
    }
    if let Some(c) = connection {
        args.push_str(&format!(", connection_name => '{}'", c));
    }

    format!("SELECT column_name FROM infer_schema({})", args)
}
4983

50-
/// Read the header line of a CSV file and return column names.
51-
fn read_csv_column_names(path: &str) -> std::io::Result<Vec<String>> {
52-
use std::io::BufRead;
53-
let file = std::fs::File::open(path)?;
54-
let mut reader = std::io::BufReader::new(file);
55-
let mut header = String::new();
56-
reader.read_line(&mut header)?;
57-
Ok(header
58-
.trim()
59-
.split(',')
60-
.map(|s| s.trim().trim_matches('"').to_string())
61-
.collect())
84+
fn build_position_select(col_names: &[String]) -> PyResult<String> {
85+
if col_names.is_empty() {
86+
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
87+
"Could not infer schema: no columns found",
88+
));
89+
}
90+
91+
Ok(col_names
92+
.iter()
93+
.enumerate()
94+
.map(|(i, name)| format!("${} AS `{}`", i + 1, name))
95+
.collect::<Vec<_>>()
96+
.join(", "))
6297
}
6398

6499
#[pyclass(name = "SessionContext", module = "databend", subclass)]
@@ -214,47 +249,21 @@ impl PySessionContext {
214249
let pattern_clause = pattern
215250
.map(|p| format!(", pattern => '{}'", p))
216251
.unwrap_or_default();
217-
218252
let select_clause = match file_format {
219-
"csv" => self.build_column_select(&file_path)?,
253+
"csv" | "tsv" => {
254+
self.build_column_select(&file_path, file_format, pattern, connection, py)?
255+
}
220256
_ => "*".to_string(),
221257
};
222-
223258
let sql = format!(
224259
"create view {} as select {} from '{}' (file_format => '{}'{}{})",
225260
name, select_clause, file_path, file_format, pattern_clause, connection_clause
226261
);
262+
227263
let _ = self.sql(&sql, py)?.collect(py)?;
228264
Ok(())
229265
}
230266

231-
/// Read CSV header from local file and build `$1 AS col1, $2 AS col2, ...`.
232-
fn build_column_select(&self, file_path: &str) -> PyResult<String> {
233-
let fs_path = fs_path_from_uri(file_path).ok_or_else(|| {
234-
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
235-
"CSV column inference only supports local files (fs://), got: {}",
236-
file_path
237-
))
238-
})?;
239-
let col_names = read_csv_column_names(fs_path).map_err(|e| {
240-
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
241-
"Failed to read CSV header: {}",
242-
e
243-
))
244-
})?;
245-
if col_names.is_empty() {
246-
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
247-
"Could not infer schema: no columns found",
248-
));
249-
}
250-
Ok(col_names
251-
.iter()
252-
.enumerate()
253-
.map(|(i, name)| format!("${} AS `{}`", i + 1, name))
254-
.collect::<Vec<_>>()
255-
.join(", "))
256-
}
257-
258267
#[pyo3(signature = (name, access_key_id, secret_access_key, endpoint_url = None, region = None))]
259268
fn create_s3_connection(
260269
&mut self,
@@ -398,8 +407,59 @@ impl PySessionContext {
398407
}
399408
}
400409

410+
impl PySessionContext {
    /// Infer the column names of a staged CSV/TSV location and build the
    /// positional select list (``$1 AS `c1`, ...``) used when registering
    /// the stage as a view.
    ///
    /// Runs an `infer_schema` query through this session itself and
    /// collects the `column_name` strings from every returned block.
    /// Fails with a `PyRuntimeError` (via `build_position_select`) when no
    /// columns could be inferred.
    fn build_column_select(
        &mut self,
        file_path: &str,
        file_format: &str,
        pattern: Option<&str>,
        connection: Option<&str>,
        py: Python,
    ) -> PyResult<String> {
        let sql = build_infer_schema_sql(file_path, file_format, pattern, connection);
        let blocks = self.sql(&sql, py)?.collect(py)?;

        // Results may be split across several blocks; skip empty ones and
        // read the first (and only projected) column of each in order.
        let col_names = blocks
            .blocks
            .iter()
            .filter(|b| b.num_rows() > 0)
            .filter_map(|b| extract_string_column(b.get_by_offset(0)))
            .flat_map(|col| col.iter().map(|s| s.to_string()))
            .collect::<Vec<_>>();

        build_position_select(&col_names)
    }
}
433+
401434
async fn plan_sql(ctx: &Arc<QueryContext>, sql: &str) -> Result<PyDataFrame> {
402435
let mut planner = Planner::new(ctx.clone());
403436
let (plan, _) = planner.plan_sql(sql).await?;
404437
Ok(PyDataFrame::new(ctx.clone(), plan, default_box_size()))
405438
}
439+
440+
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute local paths are wrapped into `fs://` URIs verbatim.
    #[test]
    fn test_resolve_file_path_absolute() {
        assert_eq!(resolve_file_path("/tmp/data.csv"), "fs:///tmp/data.csv");
    }

    // Column names become a 1-based positional projection list.
    #[test]
    fn test_build_position_select() {
        let names: Vec<String> = ["column_1", "column_2"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        assert_eq!(
            build_position_select(&names).unwrap(),
            "$1 AS `column_1`, $2 AS `column_2`"
        );
    }

    // Optional pattern/connection arguments are appended in that order and
    // the format name is upper-cased.
    #[test]
    fn test_build_infer_schema_sql_with_pattern_and_connection() {
        let expected = "SELECT column_name FROM infer_schema(location => 's3://bucket/logs/', file_format => 'TSV', pattern => '*.tsv', connection_name => 'my_s3')";
        assert_eq!(
            build_infer_schema_sql("s3://bucket/logs/", "tsv", Some("*.tsv"), Some("my_s3")),
            expected
        );
    }
}

src/bendpy/tests/test_connections.py

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,6 @@ def register_parquet(self, name, path, pattern=None, connection=None):
4343
sql = f"create view {name} as select * from '{path}' (file_format => 'parquet'{pattern_clause})"
4444
self.sql(sql)
4545

46-
def register_csv(self, name, path, pattern=None, connection=None):
47-
file_path = path if connection else (f"fs://{path}" if path.startswith("/") else path)
48-
pattern_clause = f", pattern => '{pattern}'" if pattern else ""
49-
conn_clause = f", connection => '{connection}'" if connection else ""
50-
sql = f"create view {name} as select * from '{file_path}' (file_format => 'csv'{pattern_clause}{conn_clause})"
51-
self.sql(sql)
52-
5346
def create_azblob_connection(self, name, endpoint_url, account_name, account_key):
5447
sql = f"CREATE OR REPLACE CONNECTION {name} STORAGE_TYPE = 'AZBLOB' endpoint_url = '{endpoint_url}' account_name = '{account_name}' account_key = '{account_key}'"
5548
self.sql(sql)
@@ -245,15 +238,6 @@ def test_register_parquet_with_connection_and_pattern(self):
245238
expected_sql = "create view sales as select * from 's3://bucket/data/' (file_format => 'parquet', pattern => '*.parquet', connection => 'my_s3')"
246239
mock_sql.assert_called_once_with(expected_sql)
247240

248-
def test_register_csv_with_connection(self):
249-
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
250-
mock_sql.return_value.collect.return_value = None
251-
252-
self.ctx.register_csv("users", "s3://bucket/users.csv", connection="my_s3")
253-
254-
expected_sql = "create view users as select * from 's3://bucket/users.csv' (file_format => 'csv', connection => 'my_s3')"
255-
mock_sql.assert_called_once_with(expected_sql)
256-
257241
def test_register_parquet_legacy_mode(self):
258242
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
259243
mock_sql.return_value.collect.return_value = None
@@ -263,27 +247,6 @@ def test_register_parquet_legacy_mode(self):
263247
expected_sql = "create view local as select * from '/data/file.parquet' (file_format => 'parquet')"
264248
mock_sql.assert_called_once_with(expected_sql)
265249

266-
def test_register_csv_with_pattern_no_connection(self):
267-
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
268-
mock_sql.return_value.collect.return_value = None
269-
270-
self.ctx.register_csv("logs", "/data/logs/", pattern="*.csv")
271-
272-
expected_sql = "create view logs as select * from 'fs:///data/logs/' (file_format => 'csv', pattern => '*.csv')"
273-
mock_sql.assert_called_once_with(expected_sql)
274-
275-
def test_register_csv_with_pattern_and_connection(self):
276-
with unittest.mock.patch.object(self.ctx, "sql") as mock_sql:
277-
mock_sql.return_value.collect.return_value = None
278-
279-
self.ctx.register_csv(
280-
"logs", "s3://bucket/logs/", pattern="*.csv", connection="my_s3"
281-
)
282-
283-
expected_sql = "create view logs as select * from 's3://bucket/logs/' (file_format => 'csv', pattern => '*.csv', connection => 'my_s3')"
284-
mock_sql.assert_called_once_with(expected_sql)
285-
286-
287250
class TestStages:
288251
def setup_method(self):
289252
self.ctx = MockSessionContext()
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
## Copyright 2021 Datafuse Labs
2+
##
3+
## Licensed under the Apache License, Version 2.0 (the "License");
4+
## you may not use this file except in compliance with the License.
5+
## You may obtain a copy of the License at
6+
##
7+
## http://www.apache.org/licenses/LICENSE-2.0
8+
##
9+
## Unless required by applicable law or agreed to in writing, software
10+
## distributed under the License is distributed on an "AS IS" BASIS,
11+
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
## See the License for the specific language governing permissions and
13+
## limitations under the License.
14+
15+
from pathlib import Path
16+
17+
from databend import SessionContext
18+
19+
20+
# Fixture locations, resolved relative to the repository root.
ROOT = Path(__file__).resolve().parents[3]
CSV_DIR = ROOT / "tests" / "data" / "csv"
TSV_DIR = ROOT / "tests" / "data" / "tsv"
CSV_PATH = CSV_DIR / "select.csv"
TSV_PATH = TSV_DIR / "select.tsv"


class TestRegisterDelimitedFiles:
    """End-to-end checks that register_csv/register_tsv infer columns."""

    def setup_method(self):
        # A fresh session per test keeps registered views isolated.
        self.ctx = SessionContext()

    def assert_view_rows(self, view_name):
        # Both fixtures hold the same three rows; ordering by the first
        # inferred column makes the comparison deterministic.
        frame = self.ctx.sql(
            f"select * from {view_name} order by column_1"
        ).to_pandas()
        expected = [[1, None, None], [2, "b", "B"], [3, "c", None]]
        assert frame.values.tolist() == expected

    def test_register_csv_select_star(self):
        self.ctx.register_csv("csv_stage_view", str(CSV_PATH))
        self.assert_view_rows("csv_stage_view")

    def test_register_csv_select_star_with_pattern(self):
        self.ctx.register_csv(
            "csv_stage_pattern_view", str(CSV_DIR), pattern="select.csv"
        )
        self.assert_view_rows("csv_stage_pattern_view")

    def test_register_tsv_select_star(self):
        self.ctx.register_tsv("tsv_stage_view", str(TSV_PATH))
        self.assert_view_rows("tsv_stage_view")

    def test_register_tsv_select_star_with_pattern(self):
        self.ctx.register_tsv(
            "tsv_stage_pattern_view", str(TSV_DIR), pattern="select.tsv"
        )
        self.assert_view_rows("tsv_stage_pattern_view")

0 commit comments

Comments
 (0)