2828import re
2929import operator
3030import datetime
31+ from types import NoneType
32+
3133import sqlalchemy .types as sqltypes
3234from typing import Any , Dict , Optional , Union
3335from sqlalchemy import util as sa_util
5961from sqlalchemy .engine import ExecutionContext , default
6062from sqlalchemy .exc import DBAPIError , NoSuchTableError
6163
62- from .dml import Merge
64+ from .dml import (
65+ Merge ,
66+ StageClause ,
67+ _StorageClause ,
68+ GoogleCloudStorage ,
69+ AzureBlobStorage ,
70+ AmazonS3 ,
71+ )
6372from .types import INTERVAL
6473
6574RESERVED_WORDS = {
@@ -897,13 +906,17 @@ def visit_merge(self, merge, **kw):
897906 )
898907 elif isinstance (merge .source , Subquery ):
899908 source = merge .source ._compiler_dispatch (self , ** source_kw )
909+ else :
910+ source = merge .source
911+
912+ merge_on = merge .on ._compiler_dispatch (self , ** kw )
900913
901914 target_table = self .preparer .format_table (merge .target )
902915 return (
903916 f"MERGE INTO { target_table } \n "
904917 f" USING { source } \n "
905- f" ON { merge . on } \n "
906- f"{ clauses if clauses else '' } "
918+ f" ON { merge_on } \n "
919+ f" { clauses if clauses else '' } "
907920 )
908921
909922 def visit_when_merge_matched_update (self , merge_matched_update , ** kw ):
@@ -912,7 +925,7 @@ def visit_when_merge_matched_update(self, merge_matched_update, **kw):
912925 if merge_matched_update .predicate is not None
913926 else ""
914927 )
915- update_str = f"WHEN MATCHED{ case_predicate } THEN\n " f" \t UPDATE "
928+ update_str = f"WHEN MATCHED{ case_predicate } THEN\n UPDATE "
916929 if not merge_matched_update .set :
917930 return f"{ update_str } *"
918931
@@ -941,7 +954,7 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
941954 if merge_unmatched .predicate is not None
942955 else ""
943956 )
944- insert_str = f"WHEN NOT MATCHED{ case_predicate } THEN\n " f" \t INSERT "
957+ insert_str = f"WHEN NOT MATCHED{ case_predicate } THEN\n INSERT "
945958 if not merge_unmatched .set :
946959 return f"{ insert_str } *"
947960
@@ -957,6 +970,126 @@ def visit_when_merge_unmatched(self, merge_unmatched, **kw):
957970 ", " .join (map (lambda e : e ._compiler_dispatch (self , ** kw ), sets_vals )),
958971 )
959972
973+ def visit_copy_into (self , copy_into , ** kw ):
974+ target = (
975+ self .preparer .format_table (copy_into .target )
976+ if isinstance (copy_into .target , (TableClause ,))
977+ else copy_into .target ._compiler_dispatch (self , ** kw )
978+ )
979+
980+ if isinstance (copy_into .from_ , (TableClause ,)):
981+ source = self .preparer .format_table (copy_into .from_ )
982+ elif isinstance (copy_into .from_ , (_StorageClause , StageClause )):
983+ source = copy_into .from_ ._compiler_dispatch (self , ** kw )
984+ # elif isinstance(copy_into.from_, (FileColumnClause)):
985+ # source = f"({copy_into.from_._compiler_dispatch(self, **kw)})"
986+ else :
987+ source = f"({ copy_into .from_ ._compiler_dispatch (self , ** kw )} )"
988+
989+ result = f"COPY INTO { target } " f" FROM { source } "
990+ if hasattr (copy_into , "files" ) and isinstance (copy_into .files , list ):
991+ result += f"FILES = { ', ' .join ([f for f in copy_into .files ])} "
992+ if hasattr (copy_into , "pattern" ) and copy_into .pattern :
993+ result += f" PATTERN = '{ copy_into .pattern } '"
994+ if not isinstance (copy_into .file_format , NoneType ):
995+ result += f" { copy_into .file_format ._compiler_dispatch (self , ** kw )} \n "
996+ if not isinstance (copy_into .options , NoneType ):
997+ result += f" { copy_into .options ._compiler_dispatch (self , ** kw )} \n "
998+
999+ return result
1000+
1001+ def visit_copy_format (self , file_format , ** kw ):
1002+ options_list = list (file_format .options .items ())
1003+ if kw .get ("deterministic" , False ):
1004+ options_list .sort (key = operator .itemgetter (0 ))
1005+ # predefined format name
1006+ if "format_name" in file_format .options :
1007+ return f"FILE_FORMAT=(format_name = { file_format .options ['format_name' ]} )"
1008+ # format specifics
1009+ format_options = [f"TYPE = { file_format .format_type } " ]
1010+ format_options .extend (
1011+ [
1012+ "{} = {}" .format (
1013+ option ,
1014+ (
1015+ value ._compiler_dispatch (self , ** kw )
1016+ if hasattr (value , "_compiler_dispatch" )
1017+ else str (value )
1018+ ),
1019+ )
1020+ for option , value in options_list
1021+ ]
1022+ )
1023+ return f"FILE_FORMAT = ({ ', ' .join (format_options )} )"
1024+
1025+ def visit_copy_into_options (self , copy_into_options , ** kw ):
1026+ options_list = list (copy_into_options .options .items ())
1027+ # if kw.get("deterministic", False):
1028+ # options_list.sort(key=operator.itemgetter(0))
1029+ return "\n " .join ([f"{ k } = { v } " for k , v in options_list ])
1030+
1031+ def visit_file_column (self , file_column_clause , ** kw ):
1032+ if isinstance (file_column_clause .from_ , (TableClause ,)):
1033+ source = self .preparer .format_table (file_column_clause .from_ )
1034+ elif isinstance (file_column_clause .from_ , (_StorageClause , StageClause )):
1035+ source = file_column_clause .from_ ._compiler_dispatch (self , ** kw )
1036+ else :
1037+ source = f"({ file_column_clause .from_ ._compiler_dispatch (self , ** kw )} )"
1038+ if isinstance (file_column_clause .columns , str ):
1039+ select_str = file_column_clause .columns
1040+ else :
1041+ select_str = "," .join (
1042+ [
1043+ col ._compiler_dispatch (self , ** kw )
1044+ for col in file_column_clause .columns
1045+ ]
1046+ )
1047+ return f"SELECT { select_str } " f" FROM { source } "
1048+
1049+ def visit_amazon_s3 (self , amazon_s3 : AmazonS3 , ** kw ):
1050+ connection_params_str = f" ACCESS_KEY_ID = '{ amazon_s3 .access_key_id } ' \n "
1051+ connection_params_str += (
1052+ f" SECRET_ACCESS_KEY = '{ amazon_s3 .secret_access_key } '\n "
1053+ )
1054+ if amazon_s3 .endpoint_url :
1055+ connection_params_str += f" ENDPOINT_URL = '{ amazon_s3 .endpoint_url } ' \n "
1056+ if amazon_s3 .enable_virtual_host_style :
1057+ connection_params_str += f" ENABLE_VIRTUAL_HOST_STYLE = '{ amazon_s3 .enable_virtual_host_style } '\n "
1058+ if amazon_s3 .master_key :
1059+ connection_params_str += f" MASTER_KEY = '{ amazon_s3 .master_key } '\n "
1060+ if amazon_s3 .region :
1061+ connection_params_str += f" REGION = '{ amazon_s3 .region } '\n "
1062+ if amazon_s3 .security_token :
1063+ connection_params_str += (
1064+ f" SECURITY_TOKEN = '{ amazon_s3 .security_token } '\n "
1065+ )
1066+
1067+ return (
1068+ f"'{ amazon_s3 .uri } ' \n "
1069+ f"CONNECTION = (\n "
1070+ f"{ connection_params_str } \n "
1071+ f")"
1072+ )
1073+
1074+ def visit_azure_blob_storage (self , azure : AzureBlobStorage , ** kw ):
1075+ return (
1076+ f"'{ azure .uri } ' \n "
1077+ f"CONNECTION = (\n "
1078+ f" ENDPOINT_URL = 'https://{ azure .account_name } .blob.core.windows.net' \n "
1079+ f" ACCOUNT_NAME = '{ azure .account_name } ' \n "
1080+ f" ACCOUNT_KEY = '{ azure .account_key } '\n "
1081+ f")"
1082+ )
1083+
1084+ def visit_google_cloud_storage (self , gcs : GoogleCloudStorage , ** kw ):
1085+ return (
1086+ f"'{ gcs .uri } ' \n "
1087+ f"CONNECTION = (\n "
1088+ f" ENDPOINT_URL = 'https://storage.googleapis.com' \n "
1089+ f" CREDENTIAL = '{ gcs .credentials } ' \n "
1090+ f")"
1091+ )
1092+
9601093
9611094class DatabendExecutionContext (default .DefaultExecutionContext ):
9621095 @sa_util .memoized_property
0 commit comments