3030import datetime
3131from types import NoneType
3232
33+ import sqlalchemy .engine .reflection
3334import sqlalchemy .types as sqltypes
3435from typing import Any , Dict , Optional , Union
3536from sqlalchemy import util as sa_util
4445 Subquery ,
4546)
4647from sqlalchemy .dialects .postgresql .base import PGCompiler , PGIdentifierPreparer
48+ from sqlalchemy import Table , MetaData , Column
4749from sqlalchemy .types import (
4850 BIGINT ,
4951 INTEGER ,
@@ -670,7 +672,7 @@ def process(value):
670672class DatabendDateTime (sqltypes .DATETIME ):
671673 __visit_name__ = "DATETIME"
672674
673- _reg = re .compile (r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)" )
675+ _reg = re .compile (r"(\d+)-(\d+)-(\d+) (\d+):(\d+):(\d+)\.(\d+) " )
674676
675677 def result_processor (self , dialect , coltype ):
676678 def process (value ):
@@ -698,7 +700,7 @@ def process(value):
698700class DatabendTime (sqltypes .TIME ):
699701 __visit_name__ = "TIME"
700702
701- _reg = re .compile (r"(?:\d+)-(?:\d+)-(?:\d+) (\d+):(\d+):(\d+)" )
703+ _reg = re .compile (r"(?:\d+)-(?:\d+)-(?:\d+) (\d+):(\d+):(\d+)\.(\d+) " )
702704
703705 def result_processor (self , dialect , coltype ):
704706 def process (value ):
@@ -720,7 +722,7 @@ def literal_processor(self, dialect):
720722 def process (value ):
721723 if value is not None :
722724 from_min_value = datetime .datetime .combine (
723- datetime .date (1000 , 1 , 1 ), value
725+ datetime .date (1970 , 1 , 1 ), value
724726 )
725727 time_str = from_min_value .isoformat (timespec = "microseconds" )
726728 return f"'{ time_str } '"
@@ -800,6 +802,9 @@ class DatabendIdentifierPreparer(PGIdentifierPreparer):
800802
801803
802804class DatabendCompiler (PGCompiler ):
805+ iscopyintotable : bool = False
806+ iscopyintolocation : bool = False
807+
803808 def get_select_precolumns (self , select , ** kw ):
804809 # call the base implementation because Databend doesn't support DISTINCT ON
805810 return super (PGCompiler , self ).get_select_precolumns (select , ** kw )
@@ -971,6 +976,11 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
971976 )
972977
973978 def visit_copy_into (self , copy_into , ** kw ):
979+ if isinstance (copy_into .target , (TableClause ,)):
980+ self .iscopyintotable = True
981+ else :
982+ self .iscopyintolocation = True
983+
974984 target = (
975985 self .preparer .format_table (copy_into .target )
976986 if isinstance (copy_into .target , (TableClause ,))
@@ -1090,8 +1100,21 @@ def visit_google_cloud_storage(self, gcs: GoogleCloudStorage, **kw):
10901100 f")"
10911101 )
10921102
1103+ def visit_stage (self , stage , ** kw ):
1104+ if stage .path :
1105+ return f"@{ stage .name } /{ stage .path } "
1106+ return f"@{ stage .name } "
1107+
10931108
10941109class DatabendExecutionContext (default .DefaultExecutionContext ):
1110+ iscopyintotable = False
1111+ iscopyintolocation = False
1112+
1113+ _copy_input_bytes : Optional [int ] = None
1114+ _copy_output_bytes : Optional [int ] = None
1115+ _copy_into_table_results : Optional [list [dict ]] = None
1116+ _copy_into_location_results : dict = None
1117+
10951118 @sa_util .memoized_property
10961119 def should_autocommit (self ):
10971120 return False # No DML supported, never autocommit
@@ -1102,6 +1125,38 @@ def create_server_side_cursor(self):
11021125 def create_default_cursor (self ):
11031126 return self ._dbapi_connection .cursor ()
11041127
1128+ def post_exec (self ):
1129+ self .iscopyintotable = getattr (self .compiled , 'iscopyintotable' , False )
1130+ self .iscopyintolocation = getattr (self .compiled , 'iscopyintolocation' , False )
1131+ if (self .isinsert or self .isupdate or self .isdelete or
1132+ self .iscopyintolocation or self .iscopyintotable ):
1133+ result = self .cursor .fetchall ()
1134+ if self .iscopyintotable :
1135+ self ._copy_into_table_results = [
1136+ {
1137+ 'file' : row [0 ],
1138+ 'rows_loaded' : row [1 ],
1139+ 'errors_seen' : row [2 ],
1140+ 'first_error' : row [3 ],
1141+ 'first_error_line' : row [4 ],
1142+ } for row in result
1143+ ]
1144+ self ._rowcount = sum (c ['rows_loaded' ] for c in self ._copy_into_table_results )
1145+ else :
1146+ self ._rowcount = result [0 ][0 ]
1147+ if self .iscopyintolocation :
1148+ self ._copy_into_location_results = {
1149+ 'rows_unloaded' : result [0 ][0 ],
1150+ 'input_bytes' : result [0 ][1 ],
1151+ 'output_bytes' : result [0 ][2 ],
1152+ }
1153+
1154+ def copy_into_table_results (self ) -> list [dict ]:
1155+ return self ._copy_into_table_results
1156+
1157+ def copy_into_location_results (self ) -> dict :
1158+ return self ._copy_into_location_results
1159+
11051160
11061161class DatabendTypeCompiler (compiler .GenericTypeCompiler ):
11071162 def visit_ARRAY (self , type_ , ** kw ):
@@ -1171,6 +1226,12 @@ def post_create_table(self, table):
11711226 if engine is not None :
11721227 table_opts .append (f" ENGINE={ engine } " )
11731228
1229+ if table .comment is not None :
1230+ comment = self .sql_compiler .render_literal_value (
1231+ table .comment , sqltypes .String ()
1232+ )
1233+ table_opts .append (f" COMMENT={ comment } " )
1234+
11741235 cluster_keys = db_opts .get ("cluster_by" )
11751236 if cluster_keys is not None :
11761237 if isinstance (cluster_keys , str ):
@@ -1192,6 +1253,37 @@ def post_create_table(self, table):
11921253
11931254 return " " .join (table_opts )
11941255
1256+ def get_column_specification (self , column , ** kwargs ):
1257+ colspec = super ().get_column_specification (column , ** kwargs )
1258+ comment = column .comment
1259+ if comment is not None :
1260+ literal = self .sql_compiler .render_literal_value (
1261+ comment , sqltypes .String ()
1262+ )
1263+ colspec += " COMMENT " + literal
1264+
1265+ return colspec
1266+
1267+ def visit_set_table_comment (self , create , ** kw ):
1268+ return "ALTER TABLE %s COMMENT = %s" % (
1269+ self .preparer .format_table (create .element ),
1270+ self .sql_compiler .render_literal_value (
1271+ create .element .comment , sqltypes .String ()
1272+ ),
1273+ )
1274+
1275+ def visit_drop_table_comment (self , create , ** kw ):
1276+ return "ALTER TABLE %s COMMENT = ''" % (
1277+ self .preparer .format_table (create .element )
1278+ )
1279+
1280+ def visit_set_column_comment (self , create , ** kw ):
1281+ return "ALTER TABLE %s MODIFY %s %s" % (
1282+ self .preparer .format_table (create .element .table ),
1283+ self .preparer .format_column (create .element ),
1284+ self .get_column_specification (create .element ),
1285+ )
1286+
11951287
11961288class DatabendDialect (default .DefaultDialect ):
11971289 name = "databend"
@@ -1204,7 +1296,7 @@ class DatabendDialect(default.DefaultDialect):
12041296 supports_alter = True
12051297 supports_comments = False
12061298 supports_empty_insert = False
1207- supports_is_distinct_from = False
1299+ supports_is_distinct_from = True
12081300 supports_multivalues_insert = True
12091301
12101302 supports_statement_cache = False
@@ -1316,7 +1408,7 @@ def has_table(self, connection, table_name, schema=None, **kw):
13161408 def get_columns (self , connection , table_name , schema = None , ** kw ):
13171409 query = text (
13181410 """
1319- select column_name, column_type, is_nullable
1411+ select column_name, column_type, is_nullable, nullif(column_comment, '')
13201412 from information_schema.columns
13211413 where table_name = :table_name
13221414 and table_schema = :schema_name
@@ -1337,6 +1429,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
13371429 "type" : self ._get_column_type (row [1 ]),
13381430 "nullable" : get_is_nullable (row [2 ]),
13391431 "default" : None ,
1432+ "comment" : row [3 ],
13401433 }
13411434 for row in result
13421435 ]
@@ -1416,6 +1509,23 @@ def get_table_names(self, connection, schema=None, **kw):
14161509 result = connection .execute (query , dict (schema_name = schema ))
14171510 return [row [0 ] for row in result ]
14181511
1512+ @reflection .cache
1513+ def get_temp_table_names (self , connection , schema = None , ** kw ):
1514+ table_name_query = """
1515+ select name
1516+ from system.temporary_tables
1517+ where database = :schema_name
1518+ """
1519+ query = text (table_name_query ).bindparams (
1520+ bindparam ("schema_name" , type_ = sqltypes .Unicode )
1521+ )
1522+ if schema is None :
1523+ schema = self .default_schema_name
1524+
1525+ result = connection .execute (query , dict (schema_name = schema ))
1526+ return [row [0 ] for row in result ]
1527+
1528+
14191529 @reflection .cache
14201530 def get_view_names (self , connection , schema = None , ** kw ):
14211531 view_name_query = """
@@ -1510,6 +1620,82 @@ def get_table_options(self, connection, table_name, schema=None, **kw):
15101620
15111621 return options
15121622
1623+ @reflection .cache
1624+ def get_table_comment (self , connection , table_name , schema , ** kw ):
1625+ query_text = """
1626+ SELECT comment
1627+ FROM system.tables
1628+ WHERE database = :schema_name
1629+ and name = :table_name
1630+ """
1631+ query = text (query_text ).bindparams (
1632+ bindparam ("table_name" , type_ = sqltypes .Unicode ),
1633+ bindparam ("schema_name" , type_ = sqltypes .Unicode ),
1634+ )
1635+ if schema is None :
1636+ schema = self .default_schema_name
1637+
1638+ result = connection .execute (
1639+ query , dict (table_name = table_name , schema_name = schema )
1640+ ).one_or_none ()
1641+ if not result :
1642+ raise NoSuchTableError (
1643+ f"{ self .identifier_preparer .quote_identifier (schema )} ."
1644+ f"{ self .identifier_preparer .quote_identifier (table_name )} "
1645+ )
1646+ return {'text' : result [0 ]} if result [0 ] else reflection .ReflectionDefaults .table_comment () if hasattr (reflection , 'ReflectionDefault' ) else {'text' : None }
1647+
1648+ def _prepare_filter_names (self , filter_names ):
1649+ if filter_names :
1650+ fn = [name for name in filter_names ]
1651+ return True , {"filter_names" : fn }
1652+ else :
1653+ return False , {}
1654+
1655+ def get_multi_table_comment (
1656+ self , connection , schema , filter_names , scope , kind , ** kw
1657+ ):
1658+ meta = MetaData ()
1659+ all_tab_comments = Table (
1660+ "tables" ,
1661+ meta ,
1662+ Column ("database" , VARCHAR , nullable = False ),
1663+ Column ("name" , VARCHAR , nullable = False ),
1664+ Column ("comment" , VARCHAR ),
1665+ Column ("table_type" , VARCHAR ),
1666+ schema = 'system' ,
1667+ ).alias ("a_tab_comments" )
1668+
1669+
1670+ has_filter_names , params = self ._prepare_filter_names (filter_names )
1671+ owner = schema or self .default_schema_name
1672+
1673+ table_types = set ()
1674+ if reflection .ObjectKind .TABLE in kind :
1675+ table_types .add ('BASE TABLE' )
1676+ if reflection .ObjectKind .VIEW in kind :
1677+ table_types .add ('VIEW' )
1678+
1679+ query = select (
1680+ all_tab_comments .c .name , all_tab_comments .c .comment
1681+ ).where (
1682+ all_tab_comments .c .database == owner ,
1683+ all_tab_comments .c .table_type .in_ (table_types ),
1684+ sqlalchemy .true () if reflection .ObjectScope .DEFAULT in scope else sqlalchemy .false (),
1685+ )
1686+ if has_filter_names :
1687+ query = query .where (all_tab_comments .c .name .in_ (bindparam ("filter_names" )))
1688+
1689+ result = connection .execute (query , params )
1690+ default_comment = reflection .ReflectionDefaults .table_comment
1691+ return (
1692+ (
1693+ (schema , table ),
1694+ {"text" : comment } if comment else default_comment (),
1695+ )
1696+ for table , comment in result
1697+ )
1698+
15131699 def do_rollback (self , dbapi_connection ):
15141700 # No transactions
15151701 pass
0 commit comments