sqlglot.optimizer.simplify
from __future__ import annotations

import datetime
import functools
import itertools
import typing as t
from collections import deque
from decimal import Decimal

import sqlglot
from sqlglot import Dialect, exp
from sqlglot.helper import first, merge_ranges, while_changing
from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope

if t.TYPE_CHECKING:
    from sqlglot.dialects.dialect import DialectType

    DateTruncBinaryTransform = t.Callable[
        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
    ]

# Final means that an expression should not be simplified
FINAL = "final"


class UnsupportedUnit(Exception):
    """Raised when a date/interval unit is not handled by the date helpers in this module."""


def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect resolved through `Dialect.get_or_raise` and handed to the
            DATE_TRUNC simplification

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Later rules (e.g. simplify_parens) inspect the parent of the node
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression


def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised"""

    def decorator(func):
        # functools.wraps keeps the decorated function's __name__ / __doc__ intact
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                # Best effort: hand back the untouched expression
                return expression

        return wrapped

    return decorator


def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Rewrite x between y and z to x >= y AND x <= z.

    This is done because comparison simplification is only done on lt/lte/gt/gte.
    """
    if isinstance(expression, exp.Between):
        negate = isinstance(expression.parent, exp.Not)

        expression = exp.and_(
            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
            copy=False,
        )

        if negate:
            # Keep the conjunction grouped under the enclosing NOT
            expression = exp.paren(expression, copy=False)

    return expression


COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}


def simplify_not(expression):
    """
    Demorgan's Law
    NOT (x OR y) -> NOT x AND NOT y
    NOT (x AND y) -> NOT x OR NOT y

    Also folds NOT over NULL, boolean constants, comparisons (via
    COMPLEMENT_COMPARISONS) and double negation.
    """
    if isinstance(expression, exp.Not):
        this = expression.this
        if is_null(this):
            return exp.null()
        if this.__class__ in COMPLEMENT_COMPARISONS:
            return COMPLEMENT_COMPARISONS[this.__class__](
                this=this.this, expression=this.expression
            )
        if isinstance(this, exp.Paren):
            condition = this.unnest()
            if isinstance(condition, exp.And):
                return exp.paren(
                    exp.or_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if isinstance(condition, exp.Or):
                return exp.paren(
                    exp.and_(
                        exp.not_(condition.left, copy=False),
                        exp.not_(condition.right, copy=False),
                        copy=False,
                    )
                )
            if is_null(condition):
                return exp.null()
        if always_true(this):
            return exp.false()
        if is_false(this):
            return exp.true()
        if isinstance(this, exp.Not):
            # double negation
            # NOT NOT x -> x
            return this.this
    return expression


def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression


def simplify_connectors(expression, root=True):
    """Fold AND / OR connectors whose operands are constants, duplicates or mergeable comparisons."""

    def _simplify_connectors(expression, left, right):
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression


LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

NONDETERMINISTIC = (exp.Rand, exp.Randn)


def _simplify_comparison(expression, left, right, or_=False):
    """
    Merge two comparisons that share a non-constant, deterministic operand.

    The remaining operands are compared as numbers, strings or dates. Returns the
    surviving comparison (or FALSE for a contradiction under AND), `expression`
    when no differing operand can be picked, and None when nothing applies.
    """
    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
        ll, lr = left.args.values()
        rl, rr = right.args.values()

        largs = {ll, lr}
        rargs = {rl, rr}

        matching = largs & rargs
        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}

        if matching and columns:
            try:
                l = first(largs - columns)
                r = first(rargs - columns)
            except StopIteration:
                return expression

            if l.is_number and r.is_number:
                l = float(l.name)
                r = float(r.name)
            elif l.is_string and r.is_string:
                l = l.name
                r = r.name
            else:
                l = extract_date(l)
                if not l:
                    return None
                r = extract_date(r)
                if not r:
                    return None

            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
                    return left if (av > bv if or_ else av <= bv) else right
                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
                    return left if (av < bv if or_ else av >= bv) else right

                # we can't ever shortcut to true because the column could be null
                if not or_:
                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
                        if av <= bv:
                            return exp.false()
                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
                        if av >= bv:
                            return exp.false()
                    elif isinstance(a, exp.EQ):
                        if isinstance(b, exp.LT):
                            return exp.false() if av >= bv else a
                        if isinstance(b, exp.LTE):
                            return exp.false() if av > bv else a
                        if isinstance(b, exp.GT):
                            return exp.false() if av <= bv else a
                        if isinstance(b, exp.GTE):
                            return exp.false() if av < bv else a
                        if isinstance(b, exp.NEQ):
                            return exp.false() if av == bv else a
    return None


def remove_complements(expression, root=True):
    """
    Removing complements.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        complement = exp.false() if isinstance(expression, exp.And) else exp.true()

        for a, b in itertools.permutations(expression.flatten(), 2):
            if is_complement(a, b):
                return complement
    return expression


def uniq_sort(expression, root=True):
    """
    Uniq and sort a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
        flattened = tuple(expression.flatten())
        # Keyed by the pseudo-SQL string, so identical operands collapse to one entry
        deduped = {gen(e): e for e in flattened}
        arr = tuple(deduped.items())

        # check if the operands are already sorted, if not sort them
        # A AND C AND B -> A AND B AND C
        for i, (sql, e) in enumerate(arr[1:]):
            if sql < arr[i][0]:
                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
                break
        else:
            # we didn't have to sort but maybe we need to dedup
            if len(deduped) < len(flattened):
                expression = result_func(*deduped.values(), copy=False)

    return expression


def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector we look for is the opposite kind of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression


def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        constant_mapping = {}
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    # Remember the defining column's identity so it is not replaced itself
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    and id(column) != column_id
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression


INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}


def _is_number(expression: exp.Expression) -> bool:
    """Return the expression's `is_number` flag (usable as a predicate callback)."""
    return expression.is_number


def _is_interval(expression: exp.Expression) -> bool:
    """Return True for an Interval node that `extract_interval` can convert."""
    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None

@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_equality(expression: exp.Expression) -> exp.Expression:
    """
    Use the subtraction and addition properties of equality to simplify expressions:

        x + 1 = 3 becomes x = 2

    There are two binary operations in the above expression: + and =
    Here's how we reference all the operands in the code below:

          l     r
        x + 1 = 3
        a   b

    Operands are only swapped (constant on the left, e.g. `1 + x = 3`) for
    addition, because subtraction does not commute: `1 - x = 3` is not `x = 3 + 1`.
    """
    if isinstance(expression, COMPARISONS):
        l, r = expression.left, expression.right

        if l.__class__ not in INVERSE_OPS:
            return expression

        if r.is_number:
            a_predicate = _is_number
            b_predicate = _is_number
        elif _is_date_literal(r):
            a_predicate = _is_date_literal
            b_predicate = _is_interval
        else:
            return expression

        if l.__class__ in INVERSE_DATE_OPS:
            l = t.cast(exp.IntervalOp, l)
            a = l.this
            b = l.interval()
        else:
            l = t.cast(exp.Binary, l)
            a, b = l.left, l.right

        if not a_predicate(a) and b_predicate(b):
            pass
        elif isinstance(l, exp.Add) and not a_predicate(b) and b_predicate(a):
            # Swapping is only sound for the commutative operator
            a, b = b, a
        else:
            return expression

        return expression.__class__(
            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
        )
    return expression


def simplify_literals(expression, root=True):
    """Fold constant operands of binary operators, negated numbers and date arithmetic."""
    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        this = expression.this
        if this.is_number:
            value = this.name
            if value[0] == "-":
                # Negating an already negative literal drops the sign
                return exp.Literal.number(value[1:])
            return exp.Literal.number(f"-{value}")

    if type(expression) in INVERSE_DATE_OPS:
        return _simplify_binary(expression, expression.this, expression.interval()) or expression

    return expression


NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)


def _simplify_binary(expression, a, b):
    """
    Evaluate binary `expression` over operands `a` and `b` when both are constants.

    Returns the folded literal / boolean / date literal, or None when the pair
    cannot be folded.
    """
    if isinstance(expression, exp.Is):
        if isinstance(b, exp.Not):
            c = b.this
            not_ = True
        else:
            c = b
            not_ = False

        if is_null(c):
            if isinstance(a, exp.Literal):
                return exp.true() if not_ else exp.false()
            if is_null(a):
                return exp.false() if not_ else exp.true()
    elif isinstance(expression, NULL_OK):
        return None
    elif is_null(a) or is_null(b):
        return exp.null()

    if a.is_number and b.is_number:
        num_a = int(a.name) if a.is_int else Decimal(a.name)
        num_b = int(b.name) if b.is_int else Decimal(b.name)

        if isinstance(expression, exp.Add):
            return exp.Literal.number(num_a + num_b)
        if isinstance(expression, exp.Mul):
            return exp.Literal.number(num_a * num_b)

        # We only simplify Sub, Div if a and b have the same parent because they're not associative
        if isinstance(expression, exp.Sub):
            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
        if isinstance(expression, exp.Div):
            # engines have differing int div behavior so intdiv is not safe
            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
                return None
            # Leave division by zero in the query instead of raising while optimizing
            if num_b == 0:
                return None
            return exp.Literal.number(num_a / num_b)

        boolean = eval_boolean(expression, num_a, num_b)

        if boolean:
            return boolean
    elif a.is_string and b.is_string:
        boolean = eval_boolean(expression, a.this, b.this)

        if boolean:
            return boolean
    elif _is_date_literal(a) and isinstance(b, exp.Interval):
        a, b = extract_date(a), extract_interval(b)
        if a and b:
            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
                return date_literal(a + b)
            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
                return date_literal(a - b)
    elif isinstance(a, exp.Interval) and _is_date_literal(b):
        a, b = extract_interval(a), extract_date(b)
        # you cannot subtract a date from an interval
        if a and b and isinstance(expression, exp.Add):
            return date_literal(a + b)
    elif _is_date_literal(a) and _is_date_literal(b):
        if isinstance(expression, exp.Predicate):
            a, b = extract_date(a), extract_date(b)
            boolean = eval_boolean(expression, a, b)
            if boolean:
                return boolean

    return None


def simplify_parens(expression):
    """Drop a Paren node when the parent/child combination does not need the grouping."""
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    if not isinstance(this, exp.Select) and (
        not isinstance(parent, (exp.Condition, exp.Binary))
        or isinstance(parent, exp.Paren)
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression


NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)


def _is_nonnull_constant(expression: exp.Expression) -> bool:
    """Return True for literals, booleans and date literals (NULL excluded)."""
    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)


def _is_constant(expression: exp.Expression) -> bool:
    """Return True for literals, booleans, NULL and date literals."""
    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)


def simplify_coalesce(expression):
    """
    Simplify COALESCE calls and comparisons against them.

    COALESCE(x) and COALESCE(<non-null constant>, ...) collapse to their first
    argument. A comparison between a COALESCE and a constant is split at the first
    constant argument into an explicit IS NULL / IS NOT NULL disjunction.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )


CONCATS = (exp.Concat, exp.DPipe)


def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them."""
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Adjacent string literals form one group and are merged into a single literal
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)


def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known."""
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: branches are popped from the CASE inside the loop,
        # and mutating the list being iterated would skip the branch after a removal.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression


def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Reduces a prefix check to either TRUE or FALSE if both the string and the
    prefix are statically known.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if (
        isinstance(expression, exp.StartsWith)
        and expression.this.is_string
        and expression.expression.is_string
    ):
        return exp.convert(expression.name.startswith(expression.expression.name))

    return expression


DateRange = t.Tuple[datetime.date, datetime.date]


def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
    """
    Get the date range for a DATE_TRUNC equality comparison:

    Example:
        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
    Returns:
        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
    """
    floor = date_floor(date, unit, dialect)

    if date != floor:
        # This will always be False, except for NULL values.
        return None

    return floor, floor + interval(unit)


def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
    """Get the logical expression for a date range"""
    return exp.and_(
        left >= date_literal(drange[0]),
        left < date_literal(drange[1]),
        copy=False,
    )


def _datetrunc_eq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """Rewrite `DATE_TRUNC(unit, left) = date` as a half-open range check on `left`."""
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    return _datetrunc_eq_expression(left, drange)


def _datetrunc_neq(
    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
) -> t.Optional[exp.Expression]:
    """
    Rewrite `DATE_TRUNC(unit, left) <> date` as the complement of the equality range.

    The value differs from `date` when it lies before the range OR at/after its end;
    joining both bounds with AND could never be satisfied.
    """
    drange = _datetrunc_range(date, unit, dialect)
    if not drange:
        return None

    # Parenthesized so the OR stays grouped under an enclosing AND;
    # simplify_parens removes the Paren again where it is not needed.
    return exp.paren(
        exp.or_(
            left < date_literal(drange[0]),
            left >= date_literal(drange[1]),
            copy=False,
        ),
        copy=False,
    )


DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)


def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
    """Return True when `left` is a date truncation and `right` is a date literal."""
    return isinstance(left, DATETRUNCS) and _is_date_literal(right)


@catch(ModuleNotFoundError, UnsupportedUnit)
def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
    comparison = expression.__class__

    if isinstance(expression, DATETRUNCS):
        # Truncating a literal date can be evaluated right away
        date = extract_date(expression.this)
        if date and expression.unit:
            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
    elif comparison not in DATETRUNC_COMPARISONS:
        return expression

    if isinstance(expression, exp.Binary):
        l, r = expression.left, expression.right

        if not _is_datetrunc_predicate(l, r):
            return expression

        l = t.cast(exp.DateTrunc, l)
        unit = l.unit.name.lower()
        date = extract_date(r)

        if not date:
            return expression

        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
    elif isinstance(expression, exp.In):
        l = expression.this
        rs = expression.expressions

        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
            l = t.cast(exp.DateTrunc, l)
            unit = l.unit.name.lower()

            ranges = []
            for r in rs:
                date = extract_date(r)
                if not date:
                    return expression
                drange = _datetrunc_range(date, unit, dialect)
                if drange:
                    ranges.append(drange)

            if not ranges:
                return expression

            ranges = merge_ranges(ranges)

            # Parenthesized for the same reason as in _datetrunc_neq: the result is an OR
            # that may sit under an AND; simplify_parens strips the Paren when unneeded.
            return exp.paren(
                exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False),
                copy=False,
            )

    return expression


def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """
    Put comparison operands into a canonical order.

    Columns are kept on the left and constants on the right; otherwise operands are
    ordered by their `gen` string. Swapping mirrors the operator through
    INVERSE_COMPARISONS (operators not listed there are kept as is).
    """
    if expression.__class__ in COMPLEMENT_COMPARISONS:
        l, r = expression.this, expression.expression
        l_column = isinstance(l, exp.Column)
        r_column = isinstance(r, exp.Column)
        l_const = _is_constant(l)
        r_const = _is_constant(r)

        if (l_column and not r_column) or (r_const and not l_const):
            return expression
        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
                this=r, expression=l
            )
    return expression


# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}


def remove_where_true(expression):
    """Drop WHERE clauses that are always true and turn eligible `ON <true>` joins into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            and not join.args.get("using")
            and not join.args.get("method")
            # Only the (side, kind) pairs listed in JOINS may be rewritten
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")


def always_true(expression):
    """Return a truthy value for a TRUE boolean node or for any literal node."""
    # NOTE(review): every exp.Literal counts as true here, including 0 and '' -- confirm intended.
    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
        expression, exp.Literal
    )


def always_false(expression):
    """Return True for a FALSE boolean node or a NULL node."""
    return is_false(expression) or is_null(expression)


def is_complement(a, b):
    """Return True if `b` is exactly `NOT a`."""
    return isinstance(b, exp.Not) and b.this == a


def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an exp.Boolean whose value is falsy."""
    return type(a) is exp.Boolean and not a.this


def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an exp.Null node."""
    return type(a) is exp.Null


def eval_boolean(expression, a, b):
    """Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns a boolean literal node, or None if `expression` is not a supported comparison.
    """
    if isinstance(expression, (exp.EQ, exp.Is)):
        return boolean_literal(a == b)
    if isinstance(expression, exp.NEQ):
        return boolean_literal(a != b)
    if isinstance(expression, exp.GT):
        return boolean_literal(a > b)
    if isinstance(expression, exp.GTE):
        return boolean_literal(a >= b)
    if isinstance(expression, exp.LT):
        return boolean_literal(a < b)
    if isinstance(expression, exp.LTE):
        return boolean_literal(a <= b)
    return None


def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a date, or return None if it cannot be converted."""
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not an ISO formatted string
        return None


def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime, or return None if it cannot be converted."""
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        # A plain date becomes midnight of that day
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not an ISO formatted string
        return None


def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to the Python equivalent of the temporal type `to`.

    Returns None for a falsy value, a non-temporal target type or a value that cannot be parsed.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None


def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a (possibly nested) cast of a literal to a date or datetime.

    Returns None if `cast` is not a Cast / format-less TsOrDsToDate over a literal or over
    another such cast, or if the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)


def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True if `extract_date` can evaluate `expression`."""
    return extract_date(expression) is not None


def extract_interval(expression):
    """Build a relativedelta from an interval node, or return None if that is not possible."""
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None


def date_literal(date):
    """Wrap `date` in a string literal cast to DATETIME (for datetimes) or DATE (otherwise)."""
    return exp.cast(
        exp.Literal.string(date),
        (
            exp.DataType.Type.DATETIME
            if isinstance(date, datetime.datetime)
            else exp.DataType.Type.DATE
        ),
    )


def interval(unit: str, n: int = 1):
    """Return a relativedelta of `n` units.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled unit names.
    """
    from dateutil.relativedelta import relativedelta

    if unit == "year":
        return relativedelta(years=1 * n)
    if unit == "quarter":
        return relativedelta(months=3 * n)
    if unit == "month":
        return relativedelta(months=1 * n)
    if unit == "week":
        return relativedelta(weeks=1 * n)
    if unit == "day":
        return relativedelta(days=1 * n)
    if unit == "hour":
        return relativedelta(hours=1 * n)
    if unit == "minute":
        return relativedelta(minutes=1 * n)
    if unit == "second":
        return relativedelta(seconds=1 * n)

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of its year, quarter, month, week or day.

    Raises:
        UnsupportedUnit: if `unit` is not one of the handled unit names.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        if d.month <= 3:
            return d.replace(month=1, day=1)
        elif d.month <= 6:
            return d.replace(month=4, day=1)
        elif d.month <= 9:
            return d.replace(month=7, day=1)
        else:
            return d.replace(month=10, day=1)
    if unit == "month":
        return d.replace(month=d.month, day=1)
    if unit == "week":
        # weekday() is 0 for Monday; WEEK_OFFSET shifts the first day of the week.
        # The modulo keeps the step within 0..6 days back, so the result is never after `d`
        # and a date that already is the first day of the week maps to itself.
        return d - datetime.timedelta(days=(d.weekday() - dialect.WEEK_OFFSET) % 7)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")


def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next unit boundary; a date already on a boundary is returned as is."""
    floor = date_floor(d, unit, dialect)

    if floor == d:
        return d

    return floor + interval(unit)


def boolean_literal(condition):
    """Return a TRUE node if `condition` is truthy, else a FALSE node."""
    return exp.true() if condition else exp.false()


def _flat_simplify(expression, simplifier, root=True):
    """Combine the flattened operands of `expression` pairwise using `simplifier`.

    `simplifier(expression, a, b)` is called for pairs of operands; a truthy result other than
    `expression` itself replaces both operands. If fewer operands remain than there were at the
    start, they are folded back into a chain of `expression.__class__` nodes.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # a and b were merged: retry the merged result against the rest
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # nothing could be combined with a
                operands.append(a)

        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression


def gen(expression: t.Any) -> str:
    """Simple pseudo sql generator for quickly generating sortable and uniq strings.

    Sorting and deduping sql is a necessary step for optimization. Calling the actual
    generator is expensive so we have a bare minimum sql generator here.
    """
    return Gen().gen(expression)


class Gen:
    """Stack based pseudo sql generator.

    Handlers push the pieces of a node onto `self.stack` in reverse output order, because the
    stack is popped from the end. Plain values popped from the stack are appended to `self.sqls`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo sql string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the node key followed by its set args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Drop the separator that would otherwise precede the first item.
                    # Nothing is popped when the list held only None values, so tokens
                    # belonging to the enclosing node are left alone.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # remove the "." that would precede the first part
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # remove the "." that would precede the first part
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        # Pushed right-to-left so that the left operand is popped first
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of `node` as `:key value` pairs, skipping the first `arg_index` keys.

        Returns True if at least one arg was pushed.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
Common base class for all non-exit exceptions.
Inherited Members
- builtins.Exception
- Exception
- builtins.BaseException
- with_traceback
- args
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: resolved with `Dialect.get_or_raise` and passed on to `simplify_datetrunc`

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as FINAL are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # Give the (possibly new) node the parent of the original before the remaining rules run
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression
Rewrite sqlglot AST to simplify expressions.
Example:
>>> import sqlglot >>> expression = sqlglot.parse_one("TRUE AND TRUE") >>> simplify(expression).sql() 'TRUE'
Arguments:
- expression (sqlglot.Expression): expression to simplify
- constant_propagation: whether the constant propagation rule should be used
Returns:
sqlglot.Expression: simplified expression
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised.

    The decorated function is called as usual. If it raises one of `exceptions`, the first
    positional argument (the expression) is returned unchanged instead.

    Args:
        *exceptions: the exception classes to swallow.

    Returns:
        A decorator for functions whose first argument is the expression to simplify.
    """

    def decorator(func):
        # Keep the name and docstring of the decorated function, so that it does not
        # show up as `wrapped` in tracebacks and generated documentation.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator
Decorator that ignores a simplification function if any of exceptions
are raised
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is expanded first.
    When the BETWEEN is the operand of a NOT, the conjunction is wrapped in parentheses.
    Any other expression is returned unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_bound = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_bound = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_bound, upper_bound, copy=False)

    if under_not:
        return exp.paren(conjunction, copy=False)
    return conjunction
Rewrite x between y and z to x >= y AND x <= z.
This is done because comparison simplification is only done on lt/lte/gt/gte.
def simplify_not(expression):
    """
    Simplify NOT expressions, including De Morgan's laws:
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    # NOT NULL -> NULL
    if is_null(operand):
        return exp.null()

    # NOT over a comparison becomes the complementary comparison
    operand_type = operand.__class__
    if operand_type in COMPLEMENT_COMPARISONS:
        complement = COMPLEMENT_COMPARISONS[operand_type]
        return complement(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        condition = operand.unnest()

        # De Morgan: negate both sides and join them with the dual connector
        if isinstance(condition, exp.And):
            dual = exp.or_
        elif isinstance(condition, exp.Or):
            dual = exp.and_
        else:
            dual = None

        if dual is not None:
            return exp.paren(
                dual(
                    exp.not_(condition.left, copy=False),
                    exp.not_(condition.right, copy=False),
                    copy=False,
                )
            )
        if is_null(condition):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation: NOT NOT x -> x
        return operand.this

    return expression
De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y
def flatten(expression):
    """
    Remove parentheses around operands that use the same connector as their parent:
        A AND (B AND C) -> A AND B AND C
        A OR (B OR C) -> A OR B OR C
    """
    if not isinstance(expression, exp.Connector):
        return expression

    connector_type = expression.__class__
    for operand in expression.args.values():
        inner = operand.unnest()
        if isinstance(inner, connector_type):
            operand.replace(inner)

    return expression
A AND (B AND C) -> A AND B AND C A OR (B OR C) -> A OR B OR C
def simplify_connectors(expression, root=True):
    """Simplify AND / OR connectors whose operands are statically known or comparable.

    Non-connector expressions are returned unchanged. For connectors, the rules of
    `_simplify_connectors` are applied to the operands through `_flat_simplify`.
    """

    def _simplify_connectors(expression, left, right):
        # x AND x -> x, x OR x -> x
        if left == right:
            return left
        if isinstance(expression, exp.And):
            if is_false(left) or is_false(right):
                return exp.false()
            # NOTE(review): `x AND NULL` is reduced to NULL even when x is not statically
            # known -- confirm that this is intended outside of filter contexts.
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # TRUE is the neutral element of AND
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # FALSE is the neutral element of OR
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        # Implicitly returns None for any other kind of expression

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    has_complement = any(
        is_complement(first, second)
        for first, second in itertools.permutations(expression.flatten(), 2)
    )
    if has_complement:
        return exp.false() if isinstance(expression, exp.And) else exp.true()

    return expression
Removing complements.
A AND NOT A -> FALSE A OR NOT A -> TRUE
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C
    """
    if not isinstance(expression, exp.Connector) or (not root and expression.same_parent):
        return expression

    build = exp.and_ if isinstance(expression, exp.And) else exp.or_
    flattened = tuple(expression.flatten())

    # Operands that generate the same pseudo sql are considered duplicates
    deduped = {gen(operand): operand for operand in flattened}
    keyed = tuple(deduped.items())

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later[0] < earlier[0] for earlier, later in zip(keyed, keyed[1:]))
    if out_of_order:
        return build(*(operand for _, operand in sorted(keyed)), copy=False)

    # Already sorted, but duplicates may still have to be dropped
    if len(deduped) < len(flattened):
        return build(*deduped.values(), copy=False)

    return expression
Uniq and sort a connector.
C AND A AND B AND B -> A AND B AND C
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place and `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The rules look at operands that use the dual connector of `expression`
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # b is the negation of the first operand of a
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    # b is the negation of the second operand of a
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # the operands of b are a proper subset of the operands of a
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression
absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Columns are replaced in place and `expression` itself is returned.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node in `column = literal`, the literal)
        constant_mapping = {}
        # exp.If nodes are pruned from the walk
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # leave the column of the defining `column = literal` itself alone
                    and id(column) != column_id
                    # `column IS NULL` is not rewritten
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression
Propagate constants for conjunctions in DNF:
SELECT * FROM t WHERE a = b AND b = 5 becomes SELECT * FROM t WHERE a = 5 AND b = 5
Reference: https://www.sqlite.org/optoverview.html
def wrapped(expression, *args, **kwargs):
    # Run the decorated function; if it raises one of `exceptions`,
    # hand the expression back unchanged.
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    return outcome
Use the subtraction and addition properties of equality to simplify expressions:
x + 1 = 3 becomes x = 2
There are two binary operations in the above expression: + and = Here's how we reference all the operands in the code below:
      l     r
    x + 1 = 3
    a   b
def simplify_literals(expression, root=True):
    """Fold operations on literals.

    Non-connector binary expressions are folded through `_flat_simplify`, the negation of a
    number becomes a number literal with the sign flipped, and expressions whose type is in
    INVERSE_DATE_OPS are handed to `_simplify_binary`. Anything else is returned unchanged.
    """
    is_connector = isinstance(expression, exp.Connector)
    if isinstance(expression, exp.Binary) and not is_connector:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # -(-1) -> 1, -(1) -> -1
            flipped = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(flipped)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
    """Remove a pair of parentheses when the checks below consider it redundant.

    Returns the wrapped expression if the parentheses are dropped, else `expression` itself.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parentheses around a SELECT are always kept
    if not isinstance(this, exp.Select) and (
        # the parent is neither a condition nor a binary expression
        not isinstance(parent, (exp.Condition, exp.Binary))
        # nested parentheses
        or isinstance(parent, exp.Paren)
        # a non-binary operand, unless it is a NOT / IS directly under a predicate
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # a predicate whose parent is not a predicate
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # x + (y + z), x * (y * z)
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # a multiplication inside an addition or subtraction
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons between a COALESCE and a constant.

    - `COALESCE(x)` becomes `x`; so does a COALESCE whose first argument is a non-null constant.
    - `COALESCE(x, ..., c) <op> k`, with `c` the first constant argument and `k` a constant,
      becomes `(NOT <rest> IS NULL AND <rest> <op> k) OR (<rest> IS NULL AND c <op> k)`, where
      `<rest>` is the COALESCE without `c` and everything after it.

    Any other expression is returned unchanged.
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
        coalesce_on_left = True
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
        coalesce_on_left = False
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        return expression

    # Drop the constant and everything after it; this also updates `expression` in place
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            exp.and_(
                coalesce.is_(exp.null()),
                # The constant takes the place of the COALESCE, so it has to stay on the
                # same side of the comparison: `k < COALESCE(x, c)` is `k < c`, not `c < k`.
                (
                    type(expression)(this=arg.copy(), expression=other.copy())
                    if coalesce_on_left
                    else type(expression)(this=other.copy(), expression=arg.copy())
                ),
                copy=False,
            ),
            copy=False,
        )
    )
def simplify_concat(expression):
    """Merge every run of adjacent string literals in a concatenation into one literal."""
    if not isinstance(expression, CONCATS):
        return expression

    with_separator = isinstance(expression, exp.ConcatWs)

    # We can't reduce a CONCAT_WS call if we don't statically know the separator
    if with_separator and not expression.expressions[0].is_string:
        return expression

    if with_separator:
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    operands = expressions or expression.flatten()

    new_args = []
    for literal_run, members in itertools.groupby(operands, lambda e: e.is_string):
        if not literal_run:
            new_args.extend(members)
            continue
        merged = sep.join(member.name for member in members)
        new_args.append(exp.Literal.string(merged))

    # Everything collapsed into a single string
    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if with_separator:
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)
Reduces all groups that contain string literals by concatenating them.
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For CASE, branches whose condition is always false are removed. The result of a branch
    whose condition is always true is only returned when no earlier branch with a condition
    that is not statically known was kept, because such a branch could still match first.
    If no branch is left, the default (or NULL) is returned.

    For IF (outside of a CASE), the matching result is returned when the condition is
    statically known.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Set once a branch whose outcome is not statically known has been kept
        undecided_before = False

        # Iterate over a snapshot, since branches may be popped from the CASE in the loop
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                if not undecided_before:
                    return case.args["true"]
                # An earlier branch may match at runtime, so the CASE has to stay
                continue

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
            else:
                undecided_before = True
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression
Simplifies expressions like IF, CASE if their condition is statically known.
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Evaluate STARTSWITH at simplification time when both the string and the
    prefix are string literals; other expressions are returned unchanged.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression
    if not expression.this.is_string:
        return expression
    if not expression.expression.is_string:
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))
Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.
Example:
>>> from sqlglot import parse_one >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql() 'TRUE'
def wrapped(expression, *args, **kwargs):
    # Run the decorated function; if it raises one of `exceptions`,
    # hand the expression back unchanged.
    try:
        outcome = func(expression, *args, **kwargs)
    except exceptions:
        return expression
    return outcome
Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    A lone column stays on the left and a lone constant on the right. Otherwise the
    operands are ordered by their pseudo sql. When operands are swapped, the operator
    is replaced by its entry in INVERSE_COMPARISONS, if it has one.
    """
    kind = expression.__class__
    if kind not in COMPLEMENT_COMPARISONS:
        return expression

    left, right = expression.this, expression.expression
    left_is_column = isinstance(left, exp.Column)
    right_is_column = isinstance(right, exp.Column)
    left_is_const = _is_constant(left)
    right_is_const = _is_constant(right)

    # Already in the preferred order
    if (left_is_column and not right_is_column) or (right_is_const and not left_is_const):
        return expression

    needs_swap = (
        (right_is_column and not left_is_column)
        or (left_is_const and not right_is_const)
        or (gen(left) > gen(right))
    )
    if needs_swap:
        swapped_kind = INVERSE_COMPARISONS.get(kind, kind)
        return swapped_kind(this=right, expression=left)

    return expression
def remove_where_true(expression):
    """Remove always-true WHERE clauses and turn eligible `ON <true>` joins into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where_clause in expression.find_all(exp.Where):
        if always_true(where_clause.this):
            where_clause.pop()

    for join in expression.find_all(exp.Join):
        if not always_true(join.args.get("on")):
            continue
        if join.args.get("using") or join.args.get("method"):
            continue
        # Only the (side, kind) pairs listed in JOINS may be rewritten
        if (join.side, join.kind) not in JOINS:
            continue

        join.args["on"].pop()
        join.set("side", None)
        join.set("kind", "CROSS")
def eval_boolean(expression, a, b):
    """Evaluate the comparison represented by `expression` on the Python values `a` and `b`.

    Returns a boolean literal node, or None if `expression` is not a supported comparison.
    """
    # Checked in order; the comparison is only evaluated for the matching entry
    comparisons = (
        ((exp.EQ, exp.Is), lambda: a == b),
        (exp.NEQ, lambda: a != b),
        (exp.GT, lambda: a > b),
        (exp.GTE, lambda: a >= b),
        (exp.LT, lambda: a < b),
        (exp.LTE, lambda: a <= b),
    )

    for kinds, compare in comparisons:
        if isinstance(expression, kinds):
            return boolean_literal(compare())

    return None
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a date.

    Args:
        value: a datetime, a date or an ISO formatted string.

    Returns:
        The date, or None if `value` cannot be converted.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not an ISO formatted string
        return None
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a datetime.

    Args:
        value: a datetime, a date (converted to midnight of that day) or an ISO formatted string.

    Returns:
        The datetime, or None if `value` cannot be converted.
    """
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not an ISO formatted string
        return None
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Cast `value` to the Python equivalent of the temporal type `to`.

    Args:
        value: the value to cast; a falsy value yields None.
        to: the target type. DATE is handled by `cast_as_date`, the other temporal
            types by `cast_as_datetime`.

    Returns:
        A date or datetime, or None if `to` is not a temporal type or the cast fails.
    """
    if not value:
        return None
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a (possibly nested) cast of a literal to a date or datetime.

    Args:
        cast: a Cast, or a TsOrDsToDate without a format, whose operand is a literal or
            another such cast.

    Returns:
        The date or datetime, or None if `cast` has a different shape or cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # Without a format, TsOrDsToDate is treated like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def interval(unit: str, n: int = 1):
    """Return a relativedelta that spans `n` units.

    Raises:
        UnsupportedUnit: if `unit` is not one of the unit names handled below.
    """
    from dateutil.relativedelta import relativedelta

    # unit name -> (relativedelta keyword, how many of it make up one unit)
    conversions = (
        ("year", ("years", 1)),
        ("quarter", ("months", 3)),
        ("month", ("months", 1)),
        ("week", ("weeks", 1)),
        ("day", ("days", 1)),
        ("hour", ("hours", 1)),
        ("minute", ("minutes", 1)),
        ("second", ("seconds", 1)),
    )

    for name, (keyword, factor) in conversions:
        if unit == name:
            return relativedelta(**{keyword: factor * n})

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the given unit.

    Args:
        d: the date to round.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: provides WEEK_OFFSET, which shifts the first day of the week
            relative to Monday.

    Returns:
        The first day of the year, quarter, month or week that contains `d`, or `d` itself
        for "day".

    Raises:
        UnsupportedUnit: if `unit` is not one of the units above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Quarters start in months 1, 4, 7 and 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday. The modulo keeps the step within 0..6 days back, so the
        # result is never after `d` and the first day of a week maps to itself, for any
        # WEEK_OFFSET.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def gen(expression: t.Any) -> str:
    """Generate a cheap pseudo sql string that can be used to sort and dedupe expressions.

    Sorting and deduping sql is a necessary step for optimization. The real generator is
    expensive to call, so a bare minimum generator is used instead.
    """
    generator = Gen()
    return generator.gen(expression)
Simple pseudo sql generator for quickly generating sortable and uniq strings.
Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.
class Gen:
    """Minimal stack-based pseudo-SQL generator.

    `gen` repeatedly pops the last entry of `self.stack`:

    * expressions are dispatched to a `<key>_sql` handler when one exists,
      to `_function` for other `exp.Func` nodes, and otherwise rendered as
      the upper-cased key followed by their set args;
    * lists are expanded into their non-None items separated by ",";
    * anything else that is not None is stringified and appended to the output.

    Because the stack is LIFO, every handler pushes its pieces in the
    *reverse* of the order in which they should appear in the output.
    """

    def __init__(self) -> None:
        # Work stack of pending nodes / text fragments (popped from the end).
        self.stack: t.List[t.Any] = []
        # Output fragments, joined into the final string by `gen`.
        self.sqls: t.List[str] = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo-SQL string for `expression`.

        Resets `self.stack` and `self.sqls`, so one instance can be reused
        for several calls.
        """
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                handler = getattr(self, f"{node.key}_sql", None)

                if handler is not None:
                    handler(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: "KEY" or "KEY <args>" when any arg is set.
                    # `_args` must run first so the key ends up on top of the stack.
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                self._push_joined(node, ",")
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Output order: this AS alias
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Push `NAME(expressions)` for an anonymous function.

        Raises:
            ValueError: if `e.this` is neither a str nor an `exp.Identifier`.
        """
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted identifiers keep their case; unquoted ones are upper-cased.
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Output order: this BETWEEN low AND high
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Output order: this[expressions]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        self._push_joined(e.parts, ".")

    def datatype_sql(self, e: exp.DataType) -> None:
        # Output order: "<type name> " followed by every set arg after the first.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Output order: this IN (<every set arg after the first>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        # Upper-case like every other keyword operator emitted by this class.
        self._binary(e, " LIKE ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Output order: (this) <alias> <every set arg from index 2 on>
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Output order: dotted parts, <alias>, <every set arg from index 4 on>
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self._push_joined(e.parts, ".")

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        # Output order: " AS " this, then "(columns)" when columns are present.
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Push `this <op> expression`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Push `<op>this`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Push `SQL_NAME(<all arg values>)` for a function expression."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Push the set args of `node` as a list of `[":name", value]` pairs.

        Args:
            node: expression whose args should be rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True if at least one arg was not None (and was pushed), else False.
        """
        arg_names = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        kvs = []
        for k in arg_names:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])

        if kvs:
            self.stack.append(kvs)
            return True
        return False

    def _push_joined(self, items: t.Sequence[t.Any], sep: str) -> None:
        """Push the non-None `items` so that they are emitted joined by `sep`.

        A separator is pushed after every item and the trailing one is removed
        afterwards. The removal only happens when something was actually pushed;
        otherwise an unrelated stack entry (e.g. a closing parenthesis pushed by
        the caller) would be discarded.
        """
        pushed = False

        for item in reversed(items):
            if item is not None:
                self.stack.extend((item, sep))
                pushed = True

        if pushed:
            self.stack.pop()
def gen(self, expression: exp.Expression) -> str:
    """Return the pseudo-SQL string for `expression`.

    Resets `self.stack` and `self.sqls`, then repeatedly pops the last stack
    entry until the stack is empty:

    * expressions go to a `<key>_sql` handler when one exists, to `_function`
      for other `exp.Func` nodes, and are otherwise rendered as the
      upper-cased key followed by their set args;
    * lists are expanded into their non-None items separated by ",";
    * anything else that is not None is stringified and appended to the output.
    """
    self.stack = [expression]
    self.sqls.clear()

    while self.stack:
        node = self.stack.pop()

        if isinstance(node, exp.Expression):
            exp_handler_name = f"{node.key}_sql"

            if hasattr(self, exp_handler_name):
                getattr(self, exp_handler_name)(node)
            elif isinstance(node, exp.Func):
                self._function(node)
            else:
                # `_args` must run first so the key ends up on top of the stack.
                key = node.key.upper()
                self.stack.append(f"{key} " if self._args(node) else key)
        elif type(node) is list:
            pushed = False

            for n in reversed(node):
                if n is not None:
                    self.stack.extend((n, ","))
                    pushed = True

            # Drop the trailing separator, but only if one was pushed: for a list
            # containing nothing but None this would otherwise discard an
            # unrelated stack entry.
            if pushed:
                self.stack.pop()
        elif node is not None:
            self.sqls.append(str(node))

    return "".join(self.sqls)
def anonymous_sql(self, e: exp.Anonymous) -> None:
    """Push the pieces of an anonymous function call, `NAME(expressions)`.

    A plain string name is upper-cased. An identifier name is wrapped in
    double quotes when it is quoted and upper-cased otherwise.

    Raises:
        ValueError: if `e.this` is neither a str nor an `exp.Identifier`.
    """
    this = e.this

    if isinstance(this, str):
        func_name = this.upper()
    elif isinstance(this, exp.Identifier):
        raw_name = this.this
        if this.quoted:
            func_name = f'"{raw_name}"'
        else:
            func_name = raw_name.upper()
    else:
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
        )

    # The stack is popped from the end, so the pieces are listed in reverse
    # of the order in which they are emitted.
    call_pieces = [")", e.expressions, "(", func_name]
    self.stack.extend(call_pieces)