1515use std:: sync:: Arc ;
1616
1717use databend_common_exception:: Result ;
18+ use databend_common_expression:: BlockEntry ;
19+ use databend_common_expression:: Column ;
1820use databend_common_meta_app:: principal:: BUILTIN_ROLE_ACCOUNT_ADMIN ;
1921use databend_common_version:: BUILD_INFO ;
2022use databend_query:: sessions:: BuildInfoRef ;
/// Resolve a user-supplied path into a URI understood by the storage layer.
///
/// - Paths that already carry a scheme (`s3://`, `fs://`, ...) pass through unchanged.
/// - Absolute paths are prefixed with `fs://`.
/// - Relative paths are anchored at the current working directory.
///
/// # Panics
///
/// Panics (with a descriptive message) if the current working directory cannot
/// be read or is not valid UTF-8; this is only reachable for relative paths.
fn resolve_file_path(path: &str) -> String {
    // Anything with an explicit scheme is already a URI.
    if path.contains("://") {
        return path.to_owned();
    }

    if path.starts_with('/') {
        return format!("fs://{}", path);
    }

    // Relative path: anchor it at the process working directory. State the
    // invariant instead of a bare unwrap so a failure is diagnosable.
    let cwd = std::env::current_dir().expect("current working directory must be readable");
    format!(
        "fs://{}/{}",
        cwd.to_str()
            .expect("current working directory must be valid UTF-8"),
        path
    )
}
4448
45- /// Extract the real filesystem path from a `fs://` URI.
46- fn fs_path_from_uri ( uri : & str ) -> Option < & str > {
47- uri. strip_prefix ( "fs://" )
49+ fn extract_string_column (
50+ entry : & BlockEntry ,
51+ ) -> Option < & databend_common_expression:: types:: StringColumn > {
52+ match entry {
53+ BlockEntry :: Column ( Column :: String ( col) ) => Some ( col) ,
54+ BlockEntry :: Column ( Column :: Nullable ( n) ) => match & n. column {
55+ Column :: String ( col) => Some ( col) ,
56+ _ => None ,
57+ } ,
58+ _ => None ,
59+ }
60+ }
61+
/// Build the `infer_schema` query used to discover column names for a file.
///
/// `pattern` and `connection` are appended as extra named arguments only when
/// present; `file_format` is upper-cased to match the table function's
/// expectations.
/// NOTE(review): arguments are interpolated into single-quoted SQL literals
/// without escaping — inputs containing `'` would break the statement; confirm
/// upstream validation.
fn build_infer_schema_sql(
    file_path: &str,
    file_format: &str,
    pattern: Option<&str>,
    connection: Option<&str>,
) -> String {
    // Accumulate the argument list imperatively; optional clauses are only
    // appended when supplied, preserving the pattern-before-connection order.
    let mut args = format!(
        "location => '{}', file_format => '{}'",
        file_path,
        file_format.to_uppercase()
    );
    if let Some(p) = pattern {
        args.push_str(&format!(", pattern => '{}'", p));
    }
    if let Some(c) = connection {
        args.push_str(&format!(", connection_name => '{}'", c));
    }

    format!("SELECT column_name FROM infer_schema({})", args)
}
4983
50- /// Read the header line of a CSV file and return column names.
51- fn read_csv_column_names ( path : & str ) -> std:: io:: Result < Vec < String > > {
52- use std:: io:: BufRead ;
53- let file = std:: fs:: File :: open ( path) ?;
54- let mut reader = std:: io:: BufReader :: new ( file) ;
55- let mut header = String :: new ( ) ;
56- reader. read_line ( & mut header) ?;
57- Ok ( header
58- . trim ( )
59- . split ( ',' )
60- . map ( |s| s. trim ( ) . trim_matches ( '"' ) . to_string ( ) )
61- . collect ( ) )
84+ fn build_position_select ( col_names : & [ String ] ) -> PyResult < String > {
85+ if col_names. is_empty ( ) {
86+ return Err ( PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > (
87+ "Could not infer schema: no columns found" ,
88+ ) ) ;
89+ }
90+
91+ Ok ( col_names
92+ . iter ( )
93+ . enumerate ( )
94+ . map ( |( i, name) | format ! ( "${} AS `{}`" , i + 1 , name) )
95+ . collect :: < Vec < _ > > ( )
96+ . join ( ", " ) )
6297}
6398
6499#[ pyclass( name = "SessionContext" , module = "databend" , subclass) ]
@@ -214,47 +249,21 @@ impl PySessionContext {
214249 let pattern_clause = pattern
215250 . map ( |p| format ! ( ", pattern => '{}'" , p) )
216251 . unwrap_or_default ( ) ;
217-
218252 let select_clause = match file_format {
219- "csv" => self . build_column_select ( & file_path) ?,
253+ "csv" | "tsv" => {
254+ self . build_column_select ( & file_path, file_format, pattern, connection, py) ?
255+ }
220256 _ => "*" . to_string ( ) ,
221257 } ;
222-
223258 let sql = format ! (
224259 "create view {} as select {} from '{}' (file_format => '{}'{}{})" ,
225260 name, select_clause, file_path, file_format, pattern_clause, connection_clause
226261 ) ;
262+
227263 let _ = self . sql ( & sql, py) ?. collect ( py) ?;
228264 Ok ( ( ) )
229265 }
230266
231- /// Read CSV header from local file and build `$1 AS col1, $2 AS col2, ...`.
232- fn build_column_select ( & self , file_path : & str ) -> PyResult < String > {
233- let fs_path = fs_path_from_uri ( file_path) . ok_or_else ( || {
234- PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > ( format ! (
235- "CSV column inference only supports local files (fs://), got: {}" ,
236- file_path
237- ) )
238- } ) ?;
239- let col_names = read_csv_column_names ( fs_path) . map_err ( |e| {
240- PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > ( format ! (
241- "Failed to read CSV header: {}" ,
242- e
243- ) )
244- } ) ?;
245- if col_names. is_empty ( ) {
246- return Err ( PyErr :: new :: < pyo3:: exceptions:: PyRuntimeError , _ > (
247- "Could not infer schema: no columns found" ,
248- ) ) ;
249- }
250- Ok ( col_names
251- . iter ( )
252- . enumerate ( )
253- . map ( |( i, name) | format ! ( "${} AS `{}`" , i + 1 , name) )
254- . collect :: < Vec < _ > > ( )
255- . join ( ", " ) )
256- }
257-
258267 #[ pyo3( signature = ( name, access_key_id, secret_access_key, endpoint_url = None , region = None ) ) ]
259268 fn create_s3_connection (
260269 & mut self ,
@@ -398,8 +407,59 @@ impl PySessionContext {
398407 }
399408}
400409
410+ impl PySessionContext {
411+ fn build_column_select (
412+ & mut self ,
413+ file_path : & str ,
414+ file_format : & str ,
415+ pattern : Option < & str > ,
416+ connection : Option < & str > ,
417+ py : Python ,
418+ ) -> PyResult < String > {
419+ let sql = build_infer_schema_sql ( file_path, file_format, pattern, connection) ;
420+ let blocks = self . sql ( & sql, py) ?. collect ( py) ?;
421+
422+ let col_names = blocks
423+ . blocks
424+ . iter ( )
425+ . filter ( |b| b. num_rows ( ) > 0 )
426+ . filter_map ( |b| extract_string_column ( b. get_by_offset ( 0 ) ) )
427+ . flat_map ( |col| col. iter ( ) . map ( |s| s. to_string ( ) ) )
428+ . collect :: < Vec < _ > > ( ) ;
429+
430+ build_position_select ( & col_names)
431+ }
432+ }
433+
401434async fn plan_sql ( ctx : & Arc < QueryContext > , sql : & str ) -> Result < PyDataFrame > {
402435 let mut planner = Planner :: new ( ctx. clone ( ) ) ;
403436 let ( plan, _) = planner. plan_sql ( sql) . await ?;
404437 Ok ( PyDataFrame :: new ( ctx. clone ( ) , plan, default_box_size ( ) ) )
405438}
439+
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resolve_file_path_absolute() {
        assert_eq!(resolve_file_path("/tmp/data.csv"), "fs:///tmp/data.csv");
    }

    // A path that already carries a scheme must pass through untouched.
    #[test]
    fn test_resolve_file_path_with_scheme_is_untouched() {
        assert_eq!(resolve_file_path("s3://bucket/data.csv"), "s3://bucket/data.csv");
    }

    #[test]
    fn test_build_position_select() {
        let col_names = vec!["column_1".to_string(), "column_2".to_string()];
        let select = build_position_select(&col_names).unwrap();
        assert_eq!(select, "$1 AS `column_1`, $2 AS `column_2`");
    }

    // Empty inference output is an error, not an empty projection.
    #[test]
    fn test_build_position_select_empty_is_error() {
        assert!(build_position_select(&[]).is_err());
    }

    // Minimal form: no pattern, no connection clause.
    #[test]
    fn test_build_infer_schema_sql_minimal() {
        let sql = build_infer_schema_sql("fs:///tmp/data.csv", "csv", None, None);
        assert_eq!(
            sql,
            "SELECT column_name FROM infer_schema(location => 'fs:///tmp/data.csv', file_format => 'CSV')"
        );
    }

    #[test]
    fn test_build_infer_schema_sql_with_pattern_and_connection() {
        let sql = build_infer_schema_sql("s3://bucket/logs/", "tsv", Some("*.tsv"), Some("my_s3"));

        assert_eq!(
            sql,
            "SELECT column_name FROM infer_schema(location => 's3://bucket/logs/', file_format => 'TSV', pattern => '*.tsv', connection_name => 'my_s3')"
        );
    }
}
0 commit comments