sqlglot.transforms
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
    """
    Creates a new transform by chaining a sequence of transformations and converts the resulting
    expression to SQL, using either the "_sql" method corresponding to the resulting expression,
    or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below).

    Args:
        transforms: sequence of transform functions. These will be called in order.

    Returns:
        Function that can be used as a generator transform.
    """

    def _to_sql(self, expression: exp.Expression) -> str:
        expression_type = type(expression)

        try:
            expression = transforms[0](expression)
            for transform in transforms[1:]:
                expression = transform(expression)
        except UnsupportedError as unsupported_error:
            self.unsupported(str(unsupported_error))

        _sql_handler = getattr(self, expression.key + "_sql", None)
        if _sql_handler:
            return _sql_handler(expression)

        transforms_handler = self.TRANSFORMS.get(type(expression))
        if transforms_handler:
            if expression_type is type(expression):
                if isinstance(expression, exp.Func):
                    return self.function_fallback_sql(expression)

                # Ensures we don't enter an infinite loop. This can happen when the original expression
                # has the same type as the final expression and there's no _sql method available for it,
                # because then it'd re-enter _to_sql.
                raise ValueError(
                    f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed."
                )

            return transforms_handler(self, expression)

        raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.")

    return _to_sql
Creates a new transform by chaining a sequence of transformations and converts the resulting
expression to SQL, using either the "_sql" method corresponding to the resulting expression,
or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
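In practice, preprocess is used to build entries of a dialect generator's TRANSFORMS mapping. A minimal sketch, assuming a custom Postgres-based dialect (the subclass name here is hypothetical, for illustration only):

from sqlglot import exp, transforms
from sqlglot.dialects.postgres import Postgres


class PostgresNoDistinctOn(Postgres):  # hypothetical dialect, for illustration only
    class Generator(Postgres.Generator):
        TRANSFORMS = {
            **Postgres.Generator.TRANSFORMS,
            # Every SELECT is rewritten by the chained transforms (in order), and the
            # resulting expression is then converted to SQL by the generator.
            exp.Select: transforms.preprocess([transforms.eliminate_distinct_on]),
        }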
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
    if isinstance(expression, exp.Select):
        count = 0
        recursive_ctes = []

        for unnest in expression.find_all(exp.Unnest):
            if (
                not isinstance(unnest.parent, (exp.From, exp.Join))
                or len(unnest.expressions) != 1
                or not isinstance(unnest.expressions[0], exp.GenerateDateArray)
            ):
                continue

            generate_date_array = unnest.expressions[0]
            start = generate_date_array.args.get("start")
            end = generate_date_array.args.get("end")
            step = generate_date_array.args.get("step")

            if not start or not end or not isinstance(step, exp.Interval):
                continue

            alias = unnest.args.get("alias")
            column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value"

            start = exp.cast(start, "date")
            date_add = exp.func(
                "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit")
            )
            cast_date_add = exp.cast(date_add, "date")

            cte_name = "_generated_dates" + (f"_{count}" if count else "")

            base_query = exp.select(start.as_(column_name))
            recursive_query = (
                exp.select(cast_date_add)
                .from_(cte_name)
                .where(cast_date_add <= exp.cast(end, "date"))
            )
            cte_query = base_query.union(recursive_query, distinct=False)

            generate_dates_query = exp.select(column_name).from_(cte_name)
            unnest.replace(generate_dates_query.subquery(cte_name))

            recursive_ctes.append(
                exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name])
            )
            count += 1

        if recursive_ctes:
            with_expression = expression.args.get("with") or exp.With()
            with_expression.set("recursive", True)
            with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions])
            expression.set("with", with_expression)

    return expression
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
    """Unnests GENERATE_SERIES or SEQUENCE table references."""
    this = expression.this
    if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries):
        unnest = exp.Unnest(expressions=[this])
        if expression.alias:
            return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False)

        return unnest

    return expression
Unnests GENERATE_SERIES or SEQUENCE table references.
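A minimal usage sketch, assuming the source dialect parses GENERATE_SERIES in the FROM clause into a GenerateSeries table reference; the exact output can vary between sqlglot versions:

import sqlglot
from sqlglot.transforms import unnest_generate_series

tree = sqlglot.parse_one("SELECT * FROM GENERATE_SERIES(1, 5) AS s", read="postgres")
# Table references backed by GENERATE_SERIES/SEQUENCE are wrapped in UNNEST(...).
print(tree.transform(unnest_generate_series).sql())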
def unalias_group(expression: exp.Expression) -> exp.Expression:
    """
    Replace references to select aliases in GROUP BY clauses.

    Example:
        >>> import sqlglot
        >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
        'SELECT a AS b FROM x GROUP BY 1'

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select):
        aliased_selects = {
            e.alias: i
            for i, e in enumerate(expression.parent.expressions, start=1)
            if isinstance(e, exp.Alias)
        }

        for group_by in expression.expressions:
            if (
                isinstance(group_by, exp.Column)
                and not group_by.table
                and group_by.name in aliased_selects
            ):
                group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name)))

    return expression
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT DISTINCT ON statements to a subquery with a window function.

    This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.

    Args:
        expression: the expression that will be transformed.

    Returns:
        The transformed expression.
    """
    if (
        isinstance(expression, exp.Select)
        and expression.args.get("distinct")
        and expression.args["distinct"].args.get("on")
        and isinstance(expression.args["distinct"].args["on"], exp.Tuple)
    ):
        distinct_cols = expression.args["distinct"].pop().args["on"].expressions
        outer_selects = expression.selects
        row_number = find_new_name(expression.named_selects, "_row_number")
        window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols)
        order = expression.args.get("order")

        if order:
            window.set("order", order.pop())
        else:
            window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols]))

        window = exp.alias_(window, row_number)
        expression.select(window, copy=False)

        return (
            exp.select(*outer_selects, copy=False)
            .from_(expression.subquery("_t", copy=False), copy=False)
            .where(exp.column(row_number).eq(1), copy=False)
        )

    return expression
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
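A minimal sketch of the rewrite; the commented output is only approximate:

import sqlglot
from sqlglot.transforms import eliminate_distinct_on

sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY c DESC"
tree = sqlglot.parse_one(sql, read="postgres")
print(tree.transform(eliminate_distinct_on).sql())
# Roughly: SELECT a, b
#          FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY c DESC) AS _row_number FROM t) AS _t
#          WHERE _row_number = 1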
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
    """
    Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.

    The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY:
    https://docs.snowflake.com/en/sql-reference/constructs/qualify

    Some dialects don't support window functions in the WHERE clause, so we need to include them as
    projections in the subquery, in order to refer to them in the outer filter using aliases. Also,
    if a column is referenced in the QUALIFY clause but is not selected, we need to include it too,
    otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a
    newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the
    corresponding expression to avoid creating invalid column references.
    """
    if isinstance(expression, exp.Select) and expression.args.get("qualify"):
        taken = set(expression.named_selects)
        for select in expression.selects:
            if not select.alias_or_name:
                alias = find_new_name(taken, "_c")
                select.replace(exp.alias_(select, alias))
                taken.add(alias)

        def _select_alias_or_name(select: exp.Expression) -> str | exp.Column:
            alias_or_name = select.alias_or_name
            identifier = select.args.get("alias") or select.this
            if isinstance(identifier, exp.Identifier):
                return exp.column(alias_or_name, quoted=identifier.args.get("quoted"))
            return alias_or_name

        outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects)))
        qualify_filters = expression.args["qualify"].pop().this
        expression_by_alias = {
            select.alias: select.this
            for select in expression.selects
            if isinstance(select, exp.Alias)
        }

        select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column)
        for select_candidate in qualify_filters.find_all(select_candidates):
            if isinstance(select_candidate, exp.Window):
                if expression_by_alias:
                    for column in select_candidate.find_all(exp.Column):
                        expr = expression_by_alias.get(column.name)
                        if expr:
                            column.replace(expr)

                alias = find_new_name(expression.named_selects, "_w")
                expression.select(exp.alias_(select_candidate, alias), copy=False)
                column = exp.column(alias)

                if isinstance(select_candidate.parent, exp.Qualify):
                    qualify_filters = column
                else:
                    select_candidate.replace(column)
            elif select_candidate.name not in expression.named_selects:
                expression.select(select_candidate.copy(), copy=False)

        return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where(
            qualify_filters, copy=False
        )

    return expression
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
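A minimal sketch of the rewrite for a dialect that lacks QUALIFY:

import sqlglot
from sqlglot.transforms import eliminate_qualify

sql = "SELECT id, amount FROM orders QUALIFY ROW_NUMBER() OVER (PARTITION BY id ORDER BY amount DESC) = 1"
tree = sqlglot.parse_one(sql, read="snowflake")
# The window function becomes a projection (aliased e.g. "_w") inside a "_t" subquery,
# and the outer query filters on that alias instead of using QUALIFY.
print(tree.transform(eliminate_qualify).sql())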
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects only allow the precision for parameterized types to be defined in the DDL and not in
    other expressions. This transform removes the precision from parameterized types in expressions.
    """
    for node in expression.find_all(exp.DataType):
        node.set(
            "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)]
        )

    return expression
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
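A minimal sketch; this transform operates on the whole tree, so it can be called directly:

import sqlglot
from sqlglot.transforms import remove_precision_parameterized_types

tree = sqlglot.parse_one("SELECT CAST(amount AS DECIMAL(18, 2)) FROM t")
print(remove_precision_parameterized_types(tree).sql())
# Roughly: SELECT CAST(amount AS DECIMAL) FROM t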
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
    """Remove references to unnest table aliases, added by the optimizer's qualify_columns step."""
    from sqlglot.optimizer.scope import find_all_in_scope

    if isinstance(expression, exp.Select):
        unnest_aliases = {
            unnest.alias
            for unnest in find_all_in_scope(expression, exp.Unnest)
            if isinstance(unnest.parent, (exp.From, exp.Join))
        }
        if unnest_aliases:
            for column in expression.find_all(exp.Column):
                if column.table in unnest_aliases:
                    column.set("table", None)
                elif column.db in unnest_aliases:
                    column.set("db", None)

    return expression
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(
    expression: exp.Expression,
    unnest_using_arrays_zip: bool = True,
) -> exp.Expression:
    """Convert cross join unnest into lateral view explode."""

    def _unnest_zip_exprs(
        u: exp.Unnest, unnest_exprs: t.List[exp.Expression], has_multi_expr: bool
    ) -> t.List[exp.Expression]:
        if has_multi_expr:
            if not unnest_using_arrays_zip:
                raise UnsupportedError("Cannot transpile UNNEST with multiple input arrays")

            # Use INLINE(ARRAYS_ZIP(...)) for multiple expressions
            zip_exprs: t.List[exp.Expression] = [
                exp.Anonymous(this="ARRAYS_ZIP", expressions=unnest_exprs)
            ]
            u.set("expressions", zip_exprs)
            return zip_exprs
        return unnest_exprs

    def _udtf_type(u: exp.Unnest, has_multi_expr: bool) -> t.Type[exp.Func]:
        if u.args.get("offset"):
            return exp.Posexplode
        return exp.Inline if has_multi_expr else exp.Explode

    if isinstance(expression, exp.Select):
        from_ = expression.args.get("from")

        if from_ and isinstance(from_.this, exp.Unnest):
            unnest = from_.this
            alias = unnest.args.get("alias")
            exprs = unnest.expressions
            has_multi_expr = len(exprs) > 1
            this, *expressions = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

            unnest.replace(
                exp.Table(
                    this=_udtf_type(unnest, has_multi_expr)(
                        this=this,
                        expressions=expressions,
                    ),
                    alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None,
                )
            )

        for join in expression.args.get("joins") or []:
            join_expr = join.this

            is_lateral = isinstance(join_expr, exp.Lateral)

            unnest = join_expr.this if is_lateral else join_expr

            if isinstance(unnest, exp.Unnest):
                if is_lateral:
                    alias = join_expr.args.get("alias")
                else:
                    alias = unnest.args.get("alias")
                exprs = unnest.expressions
                # The number of unnest.expressions will be changed by _unnest_zip_exprs, we need to record it here
                has_multi_expr = len(exprs) > 1
                exprs = _unnest_zip_exprs(unnest, exprs, has_multi_expr)

                expression.args["joins"].remove(join)

                alias_cols = alias.columns if alias else []
                for e, column in zip(exprs, alias_cols):
                    expression.append(
                        "laterals",
                        exp.Lateral(
                            this=_udtf_type(unnest, has_multi_expr)(this=e),
                            view=True,
                            alias=exp.TableAlias(
                                this=alias.this,  # type: ignore
                                columns=alias_cols if unnest_using_arrays_zip else [column],  # type: ignore
                            ),
                        ),
                    )

    return expression
Convert cross join unnest into lateral view explode.
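A minimal sketch, applying the transform by hand to a Presto-style query and generating Spark SQL (dialects such as Hive and Spark are expected to wire this transform into their generators via preprocess); the commented output is approximate:

import sqlglot
from sqlglot.transforms import unnest_to_explode

sql = "SELECT col FROM tbl CROSS JOIN UNNEST(arr) AS u(col)"
tree = sqlglot.parse_one(sql, read="presto")
print(unnest_to_explode(tree).sql(dialect="spark"))
# Roughly: SELECT col FROM tbl LATERAL VIEW EXPLODE(arr) u AS col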
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
    """Convert explode/posexplode into unnest."""

    def _explode_to_unnest(expression: exp.Expression) -> exp.Expression:
        if isinstance(expression, exp.Select):
            from sqlglot.optimizer.scope import Scope

            taken_select_names = set(expression.named_selects)
            taken_source_names = {name for name, _ in Scope(expression).references}

            def new_name(names: t.Set[str], name: str) -> str:
                name = find_new_name(names, name)
                names.add(name)
                return name

            arrays: t.List[exp.Condition] = []
            series_alias = new_name(taken_select_names, "pos")
            series = exp.alias_(
                exp.Unnest(
                    expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))]
                ),
                new_name(taken_source_names, "_u"),
                table=[series_alias],
            )

            # we use list here because expression.selects is mutated inside the loop
            for select in list(expression.selects):
                explode = select.find(exp.Explode)

                if explode:
                    pos_alias = ""
                    explode_alias = ""

                    if isinstance(select, exp.Alias):
                        explode_alias = select.args["alias"]
                        alias = select
                    elif isinstance(select, exp.Aliases):
                        pos_alias = select.aliases[0]
                        explode_alias = select.aliases[1]
                        alias = select.replace(exp.alias_(select.this, "", copy=False))
                    else:
                        alias = select.replace(exp.alias_(select, ""))
                        explode = alias.find(exp.Explode)
                        assert explode

                    is_posexplode = isinstance(explode, exp.Posexplode)
                    explode_arg = explode.this

                    if isinstance(explode, exp.ExplodeOuter):
                        bracket = explode_arg[0]
                        bracket.set("safe", True)
                        bracket.set("offset", True)
                        explode_arg = exp.func(
                            "IF",
                            exp.func(
                                "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array())
                            ).eq(0),
                            exp.array(bracket, copy=False),
                            explode_arg,
                        )

                    # This ensures that we won't use [POS]EXPLODE's argument as a new selection
                    if isinstance(explode_arg, exp.Column):
                        taken_select_names.add(explode_arg.output_name)

                    unnest_source_alias = new_name(taken_source_names, "_u")

                    if not explode_alias:
                        explode_alias = new_name(taken_select_names, "col")

                        if is_posexplode:
                            pos_alias = new_name(taken_select_names, "pos")

                    if not pos_alias:
                        pos_alias = new_name(taken_select_names, "pos")

                    alias.set("alias", exp.to_identifier(explode_alias))

                    series_table_alias = series.args["alias"].this
                    column = exp.If(
                        this=exp.column(series_alias, table=series_table_alias).eq(
                            exp.column(pos_alias, table=unnest_source_alias)
                        ),
                        true=exp.column(explode_alias, table=unnest_source_alias),
                    )

                    explode.replace(column)

                    if is_posexplode:
                        expressions = expression.expressions
                        expressions.insert(
                            expressions.index(alias) + 1,
                            exp.If(
                                this=exp.column(series_alias, table=series_table_alias).eq(
                                    exp.column(pos_alias, table=unnest_source_alias)
                                ),
                                true=exp.column(pos_alias, table=unnest_source_alias),
                            ).as_(pos_alias),
                        )
                        expression.set("expressions", expressions)

                    if not arrays:
                        if expression.args.get("from"):
                            expression.join(series, copy=False, join_type="CROSS")
                        else:
                            expression.from_(series, copy=False)

                    size: exp.Condition = exp.ArraySize(this=explode_arg.copy())
                    arrays.append(size)

                    # trino doesn't support left join unnest with on conditions
                    # if it did, this would be much simpler
                    expression.join(
                        exp.alias_(
                            exp.Unnest(
                                expressions=[explode_arg.copy()],
                                offset=exp.to_identifier(pos_alias),
                            ),
                            unnest_source_alias,
                            table=[explode_alias],
                        ),
                        join_type="CROSS",
                        copy=False,
                    )

                    if index_offset != 1:
                        size = size - 1

                    expression.where(
                        exp.column(series_alias, table=series_table_alias)
                        .eq(exp.column(pos_alias, table=unnest_source_alias))
                        .or_(
                            (exp.column(series_alias, table=series_table_alias) > size).and_(
                                exp.column(pos_alias, table=unnest_source_alias).eq(size)
                            )
                        ),
                        copy=False,
                    )

            if arrays:
                end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:])

                if index_offset != 1:
                    end = end - (1 - index_offset)
                series.expressions[0].set("end", end)

        return expression

    return _explode_to_unnest
Convert explode/posexplode into unnest.
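Note that explode_to_unnest is a factory: it returns the actual transform, parameterized by the target dialect's array index offset. A minimal sketch targeting a 1-indexed dialect:

import sqlglot
from sqlglot.transforms import explode_to_unnest

tree = sqlglot.parse_one("SELECT EXPLODE(arr) FROM tbl", read="spark")
# The EXPLODE projection is replaced by a CROSS JOIN UNNEST with an offset/ordinality
# column, joined against a generated position sequence and filtered in WHERE.
print(explode_to_unnest(index_offset=1)(tree).sql(dialect="trino"))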
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by adding a WITHIN GROUP clause to them."""
    if (
        isinstance(expression, exp.PERCENTILES)
        and not isinstance(expression.parent, exp.WithinGroup)
        and expression.expression
    ):
        column = expression.this.pop()
        expression.set("this", expression.expression.pop())
        order = exp.Order(expressions=[exp.Ordered(this=column)])
        expression = exp.WithinGroup(this=expression, expression=order)

    return expression
Transforms percentiles by adding a WITHIN GROUP clause to them.
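A minimal sketch, assuming the input was parsed with the quantile passed as the second function argument; the commented output is approximate:

import sqlglot
from sqlglot.transforms import add_within_group_for_percentiles

tree = sqlglot.parse_one("SELECT PERCENTILE_CONT(amount, 0.5) FROM t")
print(tree.transform(add_within_group_for_percentiles).sql())
# Roughly: SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY amount) FROM t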
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
    """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause."""
    if (
        isinstance(expression, exp.WithinGroup)
        and isinstance(expression.this, exp.PERCENTILES)
        and isinstance(expression.expression, exp.Order)
    ):
        quantile = expression.this.this
        input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this
        return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile))

    return expression
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
    """Uses projection output names in recursive CTE definitions to define the CTEs' columns."""
    if isinstance(expression, exp.With) and expression.recursive:
        next_name = name_sequence("_c_")

        for cte in expression.expressions:
            if not cte.args["alias"].columns:
                query = cte.this
                if isinstance(query, exp.SetOperation):
                    query = query.this

                cte.args["alias"].set(
                    "columns",
                    [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects],
                )

    return expression
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
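A minimal sketch; the CTE alias gains explicit column names taken from the anchor query's projections, and the commented output is approximate:

import sqlglot
from sqlglot.transforms import add_recursive_cte_column_names

sql = "WITH RECURSIVE t AS (SELECT 1 AS n UNION ALL SELECT n + 1 FROM t WHERE n < 5) SELECT * FROM t"
print(sqlglot.parse_one(sql).transform(add_recursive_cte_column_names).sql())
# Roughly: WITH RECURSIVE t(n) AS (SELECT 1 AS n UNION ALL ...) SELECT * FROM t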
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
    """Replace 'epoch' in casts by the equivalent date literal."""
    if (
        isinstance(expression, (exp.Cast, exp.TryCast))
        and expression.name.lower() == "epoch"
        and expression.to.this in exp.DataType.TEMPORAL_TYPES
    ):
        expression.this.replace(exp.Literal.string("1970-01-01 00:00:00"))

    return expression
Replace 'epoch' in casts by the equivalent date literal.
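A minimal sketch; the commented output is approximate:

import sqlglot
from sqlglot.transforms import epoch_cast_to_ts

tree = sqlglot.parse_one("SELECT CAST('epoch' AS TIMESTAMP)")
print(tree.transform(epoch_cast_to_ts).sql())
# Roughly: SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP)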
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
    """Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead."""
    if isinstance(expression, exp.Select):
        for join in expression.args.get("joins") or []:
            on = join.args.get("on")
            if on and join.kind in ("SEMI", "ANTI"):
                subquery = exp.select("1").from_(join.this).where(on)
                exists = exp.Exists(this=subquery)
                if join.kind == "ANTI":
                    exists = exists.not_(copy=False)

                join.pop()
                expression.where(exists, copy=False)

    return expression
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
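A minimal sketch; the commented output is approximate (an ANTI join would become NOT EXISTS):

import sqlglot
from sqlglot.transforms import eliminate_semi_and_anti_joins

sql = "SELECT a FROM x LEFT SEMI JOIN y ON x.id = y.id"
tree = sqlglot.parse_one(sql, read="spark")
print(tree.transform(eliminate_semi_and_anti_joins).sql())
# Roughly: SELECT a FROM x WHERE EXISTS(SELECT 1 FROM y WHERE x.id = y.id)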
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
    """
    Converts a query with a FULL OUTER join to a union of identical queries that
    use LEFT/RIGHT OUTER joins instead. This transformation currently only works
    for queries that have a single FULL OUTER join.
    """
    if isinstance(expression, exp.Select):
        full_outer_joins = [
            (index, join)
            for index, join in enumerate(expression.args.get("joins") or [])
            if join.side == "FULL"
        ]

        if len(full_outer_joins) == 1:
            expression_copy = expression.copy()
            expression.set("limit", None)
            index, full_outer_join = full_outer_joins[0]

            tables = (expression.args["from"].alias_or_name, full_outer_join.alias_or_name)
            join_conditions = full_outer_join.args.get("on") or exp.and_(
                *[
                    exp.column(col, tables[0]).eq(exp.column(col, tables[1]))
                    for col in full_outer_join.args.get("using")
                ]
            )

            full_outer_join.set("side", "left")
            anti_join_clause = exp.select("1").from_(expression.args["from"]).where(join_conditions)
            expression_copy.args["joins"][index].set("side", "right")
            expression_copy = expression_copy.where(exp.Exists(this=anti_join_clause).not_())
            expression_copy.args.pop("with", None)  # remove CTEs from RIGHT side
            expression.args.pop("order", None)  # remove order by from LEFT side

            return exp.union(expression, expression_copy, copy=False, distinct=False)

    return expression
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
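A minimal sketch of the rewrite:

import sqlglot
from sqlglot.transforms import eliminate_full_outer_join

tree = sqlglot.parse_one("SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id")
# Produces a LEFT JOIN query UNION ALL'ed with a RIGHT JOIN query that is filtered
# with NOT EXISTS, so the second branch only contributes unmatched right-side rows.
print(tree.transform(eliminate_full_outer_join).sql())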
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
    """
    Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be
    defined at the top-level, so for example queries like:

        SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq

    are invalid in those dialects. This transformation can be used to ensure all CTEs are
    moved to the top level so that the final SQL code is valid from a syntax standpoint.

    TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
    """
    top_level_with = expression.args.get("with")
    for inner_with in expression.find_all(exp.With):
        if inner_with.parent is expression:
            continue

        if not top_level_with:
            top_level_with = inner_with.pop()
            expression.set("with", top_level_with)
        else:
            if inner_with.recursive:
                top_level_with.set("recursive", True)

            parent_cte = inner_with.find_ancestor(exp.CTE)
            inner_with.pop()

            if parent_cte:
                i = top_level_with.expressions.index(parent_cte)
                top_level_with.expressions[i:i] = inner_with.expressions
                top_level_with.set("expressions", top_level_with.expressions)
            else:
                top_level_with.set(
                    "expressions", top_level_with.expressions + inner_with.expressions
                )

    return expression
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
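The docstring's own example can be used directly; after the transform, the WITH clause should precede the outer SELECT. A sketch:

import sqlglot
from sqlglot.transforms import move_ctes_to_top_level

sql = "SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq"
# The nested WITH is hoisted, giving roughly:
# WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT * FROM t) AS subq
print(move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())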
def ensure_bools(expression: exp.Expression) -> exp.Expression:
    """Converts numeric values used in conditions into explicit boolean expressions."""
    from sqlglot.optimizer.canonicalize import ensure_bools

    def _ensure_bool(node: exp.Expression) -> None:
        if (
            node.is_number
            or (
                not isinstance(node, exp.SubqueryPredicate)
                and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES)
            )
            or (isinstance(node, exp.Column) and not node.type)
        ):
            node.replace(node.neq(0))

    for node in expression.walk():
        ensure_bools(node, _ensure_bool)

    return expression
Converts numeric values used in conditions into explicit boolean expressions.
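A small sketch; a bare numeric literal used as a condition should be rewritten into an explicit comparison:

import sqlglot
from sqlglot.transforms import ensure_bools

# WHERE 1 is a numeric condition; after the transform it reads roughly WHERE 1 <> 0.
print(ensure_bools(sqlglot.parse_one("SELECT a FROM t WHERE 1")).sql())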
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
    assert isinstance(expression, exp.Create)
    properties = expression.args.get("properties")
    temporary = any(
        isinstance(prop, exp.TemporaryProperty)
        for prop in (properties.expressions if properties else [])
    )

    # CTAS with temp tables map to CREATE TEMPORARY VIEW
    if expression.kind == "TABLE" and temporary:
        if expression.expression:
            return exp.Create(
                kind="TEMPORARY VIEW",
                this=expression.this,
                expression=expression.expression,
            )
        return tmp_storage_provider(expression)

    return expression
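This transform has no docstring; per its inline comment, a CTAS over a temporary table is mapped to CREATE TEMPORARY VIEW. A minimal sketch, applied directly to a parsed CREATE statement (table and column names are illustrative):

import sqlglot
from sqlglot.transforms import ctas_with_tmp_tables_to_create_tmp_view

ast = sqlglot.parse_one("CREATE TEMPORARY TABLE t AS SELECT 1 AS c")
# The temporary CTAS becomes roughly: CREATE TEMPORARY VIEW t AS SELECT 1 AS c
print(ctas_with_tmp_tables_to_create_tmp_view(ast).sql())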
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
    """
    In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the
    PARTITIONED BY value is an array of column names, they are transformed into a schema.
    The corresponding columns are removed from the create statement.
    """
    assert isinstance(expression, exp.Create)
    has_schema = isinstance(expression.this, exp.Schema)
    is_partitionable = expression.kind in {"TABLE", "VIEW"}

    if has_schema and is_partitionable:
        prop = expression.find(exp.PartitionedByProperty)
        if prop and prop.this and not isinstance(prop.this, exp.Schema):
            schema = expression.this
            columns = {v.name.upper() for v in prop.this.expressions}
            partitions = [col for col in schema.expressions if col.name.upper() in columns]
            schema.set("expressions", [e for e in schema.expressions if e not in partitions])
            prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions)))
            expression.set("this", schema)

    return expression
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
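A minimal sketch using a hand-built CREATE, since the transform only fires when the PARTITIONED BY value is a list of plain column references rather than a Schema; the names t, a and b are illustrative:

from sqlglot import exp
from sqlglot.transforms import move_schema_columns_to_partitioned_by

# PARTITIONED BY names an existing schema column; after the transform, column `b`
# should be moved out of the column list and into a typed PARTITIONED BY schema (Hive-style).
create = exp.Create(
    kind="TABLE",
    this=exp.Schema(
        this=exp.to_table("t"),
        expressions=[
            exp.ColumnDef(this=exp.to_identifier("a"), kind=exp.DataType.build("int")),
            exp.ColumnDef(this=exp.to_identifier("b"), kind=exp.DataType.build("string")),
        ],
    ),
    properties=exp.Properties(
        expressions=[exp.PartitionedByProperty(this=exp.Tuple(expressions=[exp.column("b")]))]
    ),
)
print(move_schema_columns_to_partitioned_by(create).sql(dialect="hive"))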
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
    """
    Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.

    Currently, SQLGlot uses the DATASOURCE format for Spark 3.
    """
    assert isinstance(expression, exp.Create)
    prop = expression.find(exp.PartitionedByProperty)
    if (
        prop
        and prop.this
        and isinstance(prop.this, exp.Schema)
        and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions)
    ):
        prop_this = exp.Tuple(
            expressions=[exp.to_identifier(e.this) for e in prop.this.expressions]
        )
        schema = expression.this
        for e in prop.this.expressions:
            schema.append("expressions", e)
        prop.set("this", prop_this)

    return expression
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
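A sketch of the opposite direction, assuming the Hive dialect parses typed partition columns into the PARTITIONED BY schema; those columns should be appended to the table's column list and the property reduced to plain column names (DATASOURCE style):

import sqlglot
from sqlglot.transforms import move_partitioned_by_to_schema_columns

ast = sqlglot.parse_one("CREATE TABLE t (a INT) PARTITIONED BY (b STRING)", read="hive")
# Roughly: CREATE TABLE t (a INT, b STRING) PARTITIONED BY (b)
print(move_partitioned_by_to_schema_columns(ast).sql(dialect="spark"))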
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
    """Converts struct arguments to aliases, e.g. STRUCT(1 AS y)."""
    if isinstance(expression, exp.Struct):
        expression.set(
            "expressions",
            [
                exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e
                for e in expression.expressions
            ],
        )

    return expression
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
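A minimal sketch with a hand-built struct whose argument is a key-value pair (exp.PropertyEQ), since how structs are parsed varies by dialect; the key y and value 1 are illustrative:

from sqlglot import exp
from sqlglot.transforms import struct_kv_to_alias

struct = exp.Struct(
    expressions=[exp.PropertyEQ(this=exp.to_identifier("y"), expression=exp.Literal.number(1))]
)
# The key-value argument is rewritten to an alias, roughly: STRUCT(1 AS y)
print(struct_kv_to_alias(struct).sql())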
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
    """
    Remove join marks from an AST. This rule assumes that all marked columns are qualified.
    If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first.

    For example,
        SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
        SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this

    Args:
        expression: The AST to remove join marks from.

    Returns:
        The AST with join marks removed.
    """
    from sqlglot.optimizer.scope import traverse_scope

    for scope in traverse_scope(expression):
        query = scope.expression

        where = query.args.get("where")
        joins = query.args.get("joins")

        if not where or not joins:
            continue

        query_from = query.args["from"]

        # These keep track of the joins to be replaced
        new_joins: t.Dict[str, exp.Join] = {}
        old_joins = {join.alias_or_name: join for join in joins}

        for column in scope.columns:
            if not column.args.get("join_mark"):
                continue

            predicate = column.find_ancestor(exp.Predicate, exp.Select)
            assert isinstance(
                predicate, exp.Binary
            ), "Columns can only be marked with (+) when involved in a binary operation"

            predicate_parent = predicate.parent
            join_predicate = predicate.pop()

            left_columns = [
                c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark")
            ]
            right_columns = [
                c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark")
            ]

            assert not (
                left_columns and right_columns
            ), "The (+) marker cannot appear in both sides of a binary predicate"

            marked_column_tables = set()
            for col in left_columns or right_columns:
                table = col.table
                assert table, f"Column {col} needs to be qualified with a table"

                col.set("join_mark", False)
                marked_column_tables.add(table)

            assert (
                len(marked_column_tables) == 1
            ), "Columns of only a single table can be marked with (+) in a given binary predicate"

            join_this = old_joins.get(col.table, query_from).this
            new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT")

            # Upsert new_join into new_joins dictionary
            new_join_alias_or_name = new_join.alias_or_name
            existing_join = new_joins.get(new_join_alias_or_name)
            if existing_join:
                existing_join.set(
                    "on", exp.and_(existing_join.args.get("on"), new_join.args["on"])
                )
            else:
                new_joins[new_join_alias_or_name] = new_join

            # If the parent of the target predicate is a binary node, then it now has only one child
            if isinstance(predicate_parent, exp.Binary):
                if predicate_parent.left is None:
                    predicate_parent.replace(predicate_parent.right)
                else:
                    predicate_parent.replace(predicate_parent.left)

        if query_from.alias_or_name in new_joins:
            only_old_joins = old_joins.keys() - new_joins.keys()
            assert (
                len(only_old_joins) >= 1
            ), "Cannot determine which table to use in the new FROM clause"

            new_from_name = list(only_old_joins)[0]
            query.set("from", exp.From(this=old_joins[new_from_name].this))

        query.set("joins", list(new_joins.values()))

        if not where.this:
            where.pop()

    return expression
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify first.
For example, SELECT * FROM a, b WHERE a.id = b.id(+) is converted to SELECT * FROM a LEFT JOIN b ON a.id = b.id.
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
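A sketch using the docstring's example, assuming the Oracle dialect (which parses the (+) join markers):

import sqlglot
from sqlglot.transforms import eliminate_join_marks

ast = sqlglot.parse_one("SELECT * FROM a, b WHERE a.id = b.id(+)", read="oracle")
# Roughly: SELECT * FROM a LEFT JOIN b ON a.id = b.id
print(eliminate_join_marks(ast).sql(dialect="oracle"))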