Edit on GitHub

sqlglot.optimizer.simplify

   1from __future__ import annotations
   2
   3import datetime
   4import functools
   5import itertools
   6import typing as t
   7from collections import deque
   8from decimal import Decimal
   9
  10import sqlglot
  11from sqlglot import Dialect, exp
  12from sqlglot.helper import first, merge_ranges, while_changing
  13from sqlglot.optimizer.scope import find_all_in_scope, walk_in_scope
  14
  15if t.TYPE_CHECKING:
  16    from sqlglot.dialects.dialect import DialectType
  17
  18    DateTruncBinaryTransform = t.Callable[
  19        [exp.Expression, datetime.date, str, Dialect], t.Optional[exp.Expression]
  20    ]
  21
# Final means that an expression should not be simplified.
# This is the key looked up in `expression.meta`; `simplify` returns any
# expression carrying a truthy value under it untouched.
FINAL = "final"
  24
  25
  26class UnsupportedUnit(Exception):
  27    pass
  28
  29
  30def simplify(
  31    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
  32):
  33    """
  34    Rewrite sqlglot AST to simplify expressions.
  35
  36    Example:
  37        >>> import sqlglot
  38        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
  39        >>> simplify(expression).sql()
  40        'TRUE'
  41
  42    Args:
  43        expression (sqlglot.Expression): expression to simplify
  44        constant_propagation: whether the constant propagation rule should be used
  45
  46    Returns:
  47        sqlglot.Expression: simplified expression
  48    """
  49
  50    dialect = Dialect.get_or_raise(dialect)
  51
  52    def _simplify(expression, root=True):
  53        if expression.meta.get(FINAL):
  54            return expression
  55
  56        # group by expressions cannot be simplified, for example
  57        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
  58        # the projection must exactly match the group by key
  59        group = expression.args.get("group")
  60
  61        if group and hasattr(expression, "selects"):
  62            groups = set(group.expressions)
  63            group.meta[FINAL] = True
  64
  65            for e in expression.selects:
  66                for node in e.walk():
  67                    if node in groups:
  68                        e.meta[FINAL] = True
  69                        break
  70
  71            having = expression.args.get("having")
  72            if having:
  73                for node in having.walk():
  74                    if node in groups:
  75                        having.meta[FINAL] = True
  76                        break
  77
  78        # Pre-order transformations
  79        node = expression
  80        node = rewrite_between(node)
  81        node = uniq_sort(node, root)
  82        node = absorb_and_eliminate(node, root)
  83        node = simplify_concat(node)
  84        node = simplify_conditionals(node)
  85
  86        if constant_propagation:
  87            node = propagate_constants(node, root)
  88
  89        exp.replace_children(node, lambda e: _simplify(e, False))
  90
  91        # Post-order transformations
  92        node = simplify_not(node)
  93        node = flatten(node)
  94        node = simplify_connectors(node, root)
  95        node = remove_complements(node, root)
  96        node = simplify_coalesce(node)
  97        node.parent = expression.parent
  98        node = simplify_literals(node, root)
  99        node = simplify_equality(node)
 100        node = simplify_parens(node)
 101        node = simplify_datetrunc(node, dialect)
 102        node = sort_comparison(node)
 103        node = simplify_startswith(node)
 104
 105        if root:
 106            expression.replace(node)
 107        return node
 108
 109    expression = while_changing(expression, _simplify)
 110    remove_where_true(expression)
 111    return expression
 112
 113
 114def catch(*exceptions):
 115    """Decorator that ignores a simplification function if any of `exceptions` are raised"""
 116
 117    def decorator(func):
 118        def wrapped(expression, *args, **kwargs):
 119            try:
 120                return func(expression, *args, **kwargs)
 121            except exceptions:
 122                return expression
 123
 124        return wrapped
 125
 126    return decorator
 127
 128
 129def rewrite_between(expression: exp.Expression) -> exp.Expression:
 130    """Rewrite x between y and z to x >= y AND x <= z.
 131
 132    This is done because comparison simplification is only done on lt/lte/gt/gte.
 133    """
 134    if isinstance(expression, exp.Between):
 135        negate = isinstance(expression.parent, exp.Not)
 136
 137        expression = exp.and_(
 138            exp.GTE(this=expression.this.copy(), expression=expression.args["low"]),
 139            exp.LTE(this=expression.this.copy(), expression=expression.args["high"]),
 140            copy=False,
 141        )
 142
 143        if negate:
 144            expression = exp.paren(expression, copy=False)
 145
 146    return expression
 147
 148
# Maps a comparison class to the class of its logical negation.
# `simplify_not` uses it to rewrite `NOT (a < b)` as `a >= b`, and so on.
COMPLEMENT_COMPARISONS = {
    exp.LT: exp.GTE,
    exp.GT: exp.LTE,
    exp.LTE: exp.GT,
    exp.GTE: exp.LT,
    exp.EQ: exp.NEQ,
    exp.NEQ: exp.EQ,
}
 157
 158
 159def simplify_not(expression):
 160    """
 161    Demorgan's Law
 162    NOT (x OR y) -> NOT x AND NOT y
 163    NOT (x AND y) -> NOT x OR NOT y
 164    """
 165    if isinstance(expression, exp.Not):
 166        this = expression.this
 167        if is_null(this):
 168            return exp.null()
 169        if this.__class__ in COMPLEMENT_COMPARISONS:
 170            return COMPLEMENT_COMPARISONS[this.__class__](
 171                this=this.this, expression=this.expression
 172            )
 173        if isinstance(this, exp.Paren):
 174            condition = this.unnest()
 175            if isinstance(condition, exp.And):
 176                return exp.paren(
 177                    exp.or_(
 178                        exp.not_(condition.left, copy=False),
 179                        exp.not_(condition.right, copy=False),
 180                        copy=False,
 181                    )
 182                )
 183            if isinstance(condition, exp.Or):
 184                return exp.paren(
 185                    exp.and_(
 186                        exp.not_(condition.left, copy=False),
 187                        exp.not_(condition.right, copy=False),
 188                        copy=False,
 189                    )
 190                )
 191            if is_null(condition):
 192                return exp.null()
 193        if always_true(this):
 194            return exp.false()
 195        if is_false(this):
 196            return exp.true()
 197        if isinstance(this, exp.Not):
 198            # double negation
 199            # NOT NOT x -> x
 200            return this.this
 201    return expression
 202
 203
 204def flatten(expression):
 205    """
 206    A AND (B AND C) -> A AND B AND C
 207    A OR (B OR C) -> A OR B OR C
 208    """
 209    if isinstance(expression, exp.Connector):
 210        for node in expression.args.values():
 211            child = node.unnest()
 212            if isinstance(child, expression.__class__):
 213                node.replace(child)
 214    return expression
 215
 216
 217def simplify_connectors(expression, root=True):
 218    def _simplify_connectors(expression, left, right):
 219        if left == right:
 220            return left
 221        if isinstance(expression, exp.And):
 222            if is_false(left) or is_false(right):
 223                return exp.false()
 224            if is_null(left) or is_null(right):
 225                return exp.null()
 226            if always_true(left) and always_true(right):
 227                return exp.true()
 228            if always_true(left):
 229                return right
 230            if always_true(right):
 231                return left
 232            return _simplify_comparison(expression, left, right)
 233        elif isinstance(expression, exp.Or):
 234            if always_true(left) or always_true(right):
 235                return exp.true()
 236            if is_false(left) and is_false(right):
 237                return exp.false()
 238            if (
 239                (is_null(left) and is_null(right))
 240                or (is_null(left) and is_false(right))
 241                or (is_false(left) and is_null(right))
 242            ):
 243                return exp.null()
 244            if is_false(left):
 245                return right
 246            if is_false(right):
 247                return left
 248            return _simplify_comparison(expression, left, right, or_=True)
 249
 250    if isinstance(expression, exp.Connector):
 251        return _flat_simplify(expression, _simplify_connectors, root)
 252    return expression
 253
 254
# "Less than" and "greater than" comparison classes, strict and non-strict
LT_LTE = (exp.LT, exp.LTE)
GT_GTE = (exp.GT, exp.GTE)

# Node types treated as comparisons by the rules in this module
# (e.g. `_simplify_comparison`, `simplify_equality`, `simplify_coalesce`)
COMPARISONS = (
    *LT_LTE,
    *GT_GTE,
    exp.EQ,
    exp.NEQ,
    exp.Is,
)

# Maps each inequality to its mirror image (LT <-> GT, LTE <-> GTE).
# NOTE(review): presumably used when the two operands of a comparison are
# swapped; the consumer is not visible in this part of the file.
INVERSE_COMPARISONS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.LT: exp.GT,
    exp.GT: exp.LT,
    exp.LTE: exp.GTE,
    exp.GTE: exp.LTE,
}

# `_simplify_comparison` does not treat an operand containing one of these
# node types as a column shared between two comparisons
NONDETERMINISTIC = (exp.Rand, exp.Randn)
 274
 275
 276def _simplify_comparison(expression, left, right, or_=False):
 277    if isinstance(left, COMPARISONS) and isinstance(right, COMPARISONS):
 278        ll, lr = left.args.values()
 279        rl, rr = right.args.values()
 280
 281        largs = {ll, lr}
 282        rargs = {rl, rr}
 283
 284        matching = largs & rargs
 285        columns = {m for m in matching if not _is_constant(m) and not m.find(*NONDETERMINISTIC)}
 286
 287        if matching and columns:
 288            try:
 289                l = first(largs - columns)
 290                r = first(rargs - columns)
 291            except StopIteration:
 292                return expression
 293
 294            if l.is_number and r.is_number:
 295                l = float(l.name)
 296                r = float(r.name)
 297            elif l.is_string and r.is_string:
 298                l = l.name
 299                r = r.name
 300            else:
 301                l = extract_date(l)
 302                if not l:
 303                    return None
 304                r = extract_date(r)
 305                if not r:
 306                    return None
 307
 308            for (a, av), (b, bv) in itertools.permutations(((left, l), (right, r))):
 309                if isinstance(a, LT_LTE) and isinstance(b, LT_LTE):
 310                    return left if (av > bv if or_ else av <= bv) else right
 311                if isinstance(a, GT_GTE) and isinstance(b, GT_GTE):
 312                    return left if (av < bv if or_ else av >= bv) else right
 313
 314                # we can't ever shortcut to true because the column could be null
 315                if not or_:
 316                    if isinstance(a, exp.LT) and isinstance(b, GT_GTE):
 317                        if av <= bv:
 318                            return exp.false()
 319                    elif isinstance(a, exp.GT) and isinstance(b, LT_LTE):
 320                        if av >= bv:
 321                            return exp.false()
 322                    elif isinstance(a, exp.EQ):
 323                        if isinstance(b, exp.LT):
 324                            return exp.false() if av >= bv else a
 325                        if isinstance(b, exp.LTE):
 326                            return exp.false() if av > bv else a
 327                        if isinstance(b, exp.GT):
 328                            return exp.false() if av <= bv else a
 329                        if isinstance(b, exp.GTE):
 330                            return exp.false() if av < bv else a
 331                        if isinstance(b, exp.NEQ):
 332                            return exp.false() if av == bv else a
 333    return None
 334
 335
 336def remove_complements(expression, root=True):
 337    """
 338    Removing complements.
 339
 340    A AND NOT A -> FALSE
 341    A OR NOT A -> TRUE
 342    """
 343    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 344        complement = exp.false() if isinstance(expression, exp.And) else exp.true()
 345
 346        for a, b in itertools.permutations(expression.flatten(), 2):
 347            if is_complement(a, b):
 348                return complement
 349    return expression
 350
 351
 352def uniq_sort(expression, root=True):
 353    """
 354    Uniq and sort a connector.
 355
 356    C AND A AND B AND B -> A AND B AND C
 357    """
 358    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 359        result_func = exp.and_ if isinstance(expression, exp.And) else exp.or_
 360        flattened = tuple(expression.flatten())
 361        deduped = {gen(e): e for e in flattened}
 362        arr = tuple(deduped.items())
 363
 364        # check if the operands are already sorted, if not sort them
 365        # A AND C AND B -> A AND B AND C
 366        for i, (sql, e) in enumerate(arr[1:]):
 367            if sql < arr[i][0]:
 368                expression = result_func(*(e for _, e in sorted(arr)), copy=False)
 369                break
 370        else:
 371            # we didn't have to sort but maybe we need to dedup
 372            if len(deduped) < len(flattened):
 373                expression = result_func(*deduped.values(), copy=False)
 374
 375    return expression
 376
 377
 378def absorb_and_eliminate(expression, root=True):
 379    """
 380    absorption:
 381        A AND (A OR B) -> A
 382        A OR (A AND B) -> A
 383        A AND (NOT A OR B) -> A AND B
 384        A OR (NOT A AND B) -> A OR B
 385    elimination:
 386        (A AND B) OR (A AND NOT B) -> A
 387        (A OR B) AND (A OR NOT B) -> A
 388    """
 389    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
 390        kind = exp.Or if isinstance(expression, exp.And) else exp.And
 391
 392        for a, b in itertools.permutations(expression.flatten(), 2):
 393            if isinstance(a, kind):
 394                aa, ab = a.unnest_operands()
 395
 396                # absorb
 397                if is_complement(b, aa):
 398                    aa.replace(exp.true() if kind == exp.And else exp.false())
 399                elif is_complement(b, ab):
 400                    ab.replace(exp.true() if kind == exp.And else exp.false())
 401                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
 402                    a.replace(exp.false() if kind == exp.And else exp.true())
 403                elif isinstance(b, kind):
 404                    # eliminate
 405                    rhs = b.unnest_operands()
 406                    ba, bb = rhs
 407
 408                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
 409                        a.replace(aa)
 410                        b.replace(aa)
 411                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
 412                        a.replace(ab)
 413                        b.replace(ab)
 414
 415    return expression
 416
 417
 418def propagate_constants(expression, root=True):
 419    """
 420    Propagate constants for conjunctions in DNF:
 421
 422    SELECT * FROM t WHERE a = b AND b = 5 becomes
 423    SELECT * FROM t WHERE a = 5 AND b = 5
 424
 425    Reference: https://www.sqlite.org/optoverview.html
 426    """
 427
 428    if (
 429        isinstance(expression, exp.And)
 430        and (root or not expression.same_parent)
 431        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
 432    ):
 433        constant_mapping = {}
 434        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
 435            if isinstance(expr, exp.EQ):
 436                l, r = expr.left, expr.right
 437
 438                # TODO: create a helper that can be used to detect nested literal expressions such
 439                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
 440                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
 441                    constant_mapping[l] = (id(l), r)
 442
 443        if constant_mapping:
 444            for column in find_all_in_scope(expression, exp.Column):
 445                parent = column.parent
 446                column_id, constant = constant_mapping.get(column) or (None, None)
 447                if (
 448                    column_id is not None
 449                    and id(column) != column_id
 450                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
 451                ):
 452                    column.replace(constant.copy())
 453
 454    return expression
 455
 456
# Date arithmetic functions mapped to the binary operator that undoes them
INVERSE_DATE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    exp.DateAdd: exp.Sub,
    exp.DateSub: exp.Add,
    exp.DatetimeAdd: exp.Sub,
    exp.DatetimeSub: exp.Add,
}

# All operations `simplify_equality` can move to the other side of a comparison,
# mapped to the operator applied to the right-hand side when doing so
INVERSE_OPS: t.Dict[t.Type[exp.Expression], t.Type[exp.Expression]] = {
    **INVERSE_DATE_OPS,
    exp.Add: exp.Sub,
    exp.Sub: exp.Add,
}
 469
 470
 471def _is_number(expression: exp.Expression) -> bool:
 472    return expression.is_number
 473
 474
 475def _is_interval(expression: exp.Expression) -> bool:
 476    return isinstance(expression, exp.Interval) and extract_interval(expression) is not None
 477
 478
 479@catch(ModuleNotFoundError, UnsupportedUnit)
 480def simplify_equality(expression: exp.Expression) -> exp.Expression:
 481    """
 482    Use the subtraction and addition properties of equality to simplify expressions:
 483
 484        x + 1 = 3 becomes x = 2
 485
 486    There are two binary operations in the above expression: + and =
 487    Here's how we reference all the operands in the code below:
 488
 489          l     r
 490        x + 1 = 3
 491        a   b
 492    """
 493    if isinstance(expression, COMPARISONS):
 494        l, r = expression.left, expression.right
 495
 496        if l.__class__ not in INVERSE_OPS:
 497            return expression
 498
 499        if r.is_number:
 500            a_predicate = _is_number
 501            b_predicate = _is_number
 502        elif _is_date_literal(r):
 503            a_predicate = _is_date_literal
 504            b_predicate = _is_interval
 505        else:
 506            return expression
 507
 508        if l.__class__ in INVERSE_DATE_OPS:
 509            l = t.cast(exp.IntervalOp, l)
 510            a = l.this
 511            b = l.interval()
 512        else:
 513            l = t.cast(exp.Binary, l)
 514            a, b = l.left, l.right
 515
 516        if not a_predicate(a) and b_predicate(b):
 517            pass
 518        elif not a_predicate(b) and b_predicate(a):
 519            a, b = b, a
 520        else:
 521            return expression
 522
 523        return expression.__class__(
 524            this=a, expression=INVERSE_OPS[l.__class__](this=r, expression=b)
 525        )
 526    return expression
 527
 528
 529def simplify_literals(expression, root=True):
 530    if isinstance(expression, exp.Binary) and not isinstance(expression, exp.Connector):
 531        return _flat_simplify(expression, _simplify_binary, root)
 532
 533    if isinstance(expression, exp.Neg):
 534        this = expression.this
 535        if this.is_number:
 536            value = this.name
 537            if value[0] == "-":
 538                return exp.Literal.number(value[1:])
 539            return exp.Literal.number(f"-{value}")
 540
 541    if type(expression) in INVERSE_DATE_OPS:
 542        return _simplify_binary(expression, expression.this, expression.interval()) or expression
 543
 544    return expression
 545
 546
# Binary operators that `_simplify_binary` never folds, not even to NULL when an operand is NULL
NULL_OK = (exp.NullSafeEQ, exp.NullSafeNEQ, exp.PropertyEQ)
 548
 549
 550def _simplify_binary(expression, a, b):
 551    if isinstance(expression, exp.Is):
 552        if isinstance(b, exp.Not):
 553            c = b.this
 554            not_ = True
 555        else:
 556            c = b
 557            not_ = False
 558
 559        if is_null(c):
 560            if isinstance(a, exp.Literal):
 561                return exp.true() if not_ else exp.false()
 562            if is_null(a):
 563                return exp.false() if not_ else exp.true()
 564    elif isinstance(expression, NULL_OK):
 565        return None
 566    elif is_null(a) or is_null(b):
 567        return exp.null()
 568
 569    if a.is_number and b.is_number:
 570        num_a = int(a.name) if a.is_int else Decimal(a.name)
 571        num_b = int(b.name) if b.is_int else Decimal(b.name)
 572
 573        if isinstance(expression, exp.Add):
 574            return exp.Literal.number(num_a + num_b)
 575        if isinstance(expression, exp.Mul):
 576            return exp.Literal.number(num_a * num_b)
 577
 578        # We only simplify Sub, Div if a and b have the same parent because they're not associative
 579        if isinstance(expression, exp.Sub):
 580            return exp.Literal.number(num_a - num_b) if a.parent is b.parent else None
 581        if isinstance(expression, exp.Div):
 582            # engines have differing int div behavior so intdiv is not safe
 583            if (isinstance(num_a, int) and isinstance(num_b, int)) or a.parent is not b.parent:
 584                return None
 585            return exp.Literal.number(num_a / num_b)
 586
 587        boolean = eval_boolean(expression, num_a, num_b)
 588
 589        if boolean:
 590            return boolean
 591    elif a.is_string and b.is_string:
 592        boolean = eval_boolean(expression, a.this, b.this)
 593
 594        if boolean:
 595            return boolean
 596    elif _is_date_literal(a) and isinstance(b, exp.Interval):
 597        a, b = extract_date(a), extract_interval(b)
 598        if a and b:
 599            if isinstance(expression, (exp.Add, exp.DateAdd, exp.DatetimeAdd)):
 600                return date_literal(a + b)
 601            if isinstance(expression, (exp.Sub, exp.DateSub, exp.DatetimeSub)):
 602                return date_literal(a - b)
 603    elif isinstance(a, exp.Interval) and _is_date_literal(b):
 604        a, b = extract_interval(a), extract_date(b)
 605        # you cannot subtract a date from an interval
 606        if a and b and isinstance(expression, exp.Add):
 607            return date_literal(a + b)
 608    elif _is_date_literal(a) and _is_date_literal(b):
 609        if isinstance(expression, exp.Predicate):
 610            a, b = extract_date(a), extract_date(b)
 611            boolean = eval_boolean(expression, a, b)
 612            if boolean:
 613                return boolean
 614
 615    return None
 616
 617
 618def simplify_parens(expression):
 619    if not isinstance(expression, exp.Paren):
 620        return expression
 621
 622    this = expression.this
 623    parent = expression.parent
 624    parent_is_predicate = isinstance(parent, exp.Predicate)
 625
 626    if not isinstance(this, exp.Select) and (
 627        not isinstance(parent, (exp.Condition, exp.Binary))
 628        or isinstance(parent, exp.Paren)
 629        or (
 630            not isinstance(this, exp.Binary)
 631            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
 632        )
 633        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
 634        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
 635        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
 636        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
 637    ):
 638        return this
 639    return expression
 640
 641
# Constant node types that are never NULL (see `_is_nonnull_constant`)
NONNULL_CONSTANTS = (
    exp.Literal,
    exp.Boolean,
)

# All constant node types, NULL included (see `_is_constant`)
CONSTANTS = (
    exp.Literal,
    exp.Boolean,
    exp.Null,
)
 652
 653
 654def _is_nonnull_constant(expression: exp.Expression) -> bool:
 655    return isinstance(expression, NONNULL_CONSTANTS) or _is_date_literal(expression)
 656
 657
 658def _is_constant(expression: exp.Expression) -> bool:
 659    return isinstance(expression, CONSTANTS) or _is_date_literal(expression)
 660
 661
 662def simplify_coalesce(expression):
 663    # COALESCE(x) -> x
 664    if (
 665        isinstance(expression, exp.Coalesce)
 666        and (not expression.expressions or _is_nonnull_constant(expression.this))
 667        # COALESCE is also used as a Spark partitioning hint
 668        and not isinstance(expression.parent, exp.Hint)
 669    ):
 670        return expression.this
 671
 672    if not isinstance(expression, COMPARISONS):
 673        return expression
 674
 675    if isinstance(expression.left, exp.Coalesce):
 676        coalesce = expression.left
 677        other = expression.right
 678    elif isinstance(expression.right, exp.Coalesce):
 679        coalesce = expression.right
 680        other = expression.left
 681    else:
 682        return expression
 683
 684    # This transformation is valid for non-constants,
 685    # but it really only does anything if they are both constants.
 686    if not _is_constant(other):
 687        return expression
 688
 689    # Find the first constant arg
 690    for arg_index, arg in enumerate(coalesce.expressions):
 691        if _is_constant(arg):
 692            break
 693    else:
 694        return expression
 695
 696    coalesce.set("expressions", coalesce.expressions[:arg_index])
 697
 698    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
 699    # since we already remove COALESCE at the top of this function.
 700    coalesce = coalesce if coalesce.expressions else coalesce.this
 701
 702    # This expression is more complex than when we started, but it will get simplified further
 703    return exp.paren(
 704        exp.or_(
 705            exp.and_(
 706                coalesce.is_(exp.null()).not_(copy=False),
 707                expression.copy(),
 708                copy=False,
 709            ),
 710            exp.and_(
 711                coalesce.is_(exp.null()),
 712                type(expression)(this=arg.copy(), expression=other.copy()),
 713                copy=False,
 714            ),
 715            copy=False,
 716        )
 717    )
 718
 719
# Node types that `simplify_concat` operates on
CONCATS = (exp.Concat, exp.DPipe)
 721
 722
 723def simplify_concat(expression):
 724    """Reduces all groups that contain string literals by concatenating them."""
 725    if not isinstance(expression, CONCATS) or (
 726        # We can't reduce a CONCAT_WS call if we don't statically know the separator
 727        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
 728    ):
 729        return expression
 730
 731    if isinstance(expression, exp.ConcatWs):
 732        sep_expr, *expressions = expression.expressions
 733        sep = sep_expr.name
 734        concat_type = exp.ConcatWs
 735        args = {}
 736    else:
 737        expressions = expression.expressions
 738        sep = ""
 739        concat_type = exp.Concat
 740        args = {
 741            "safe": expression.args.get("safe"),
 742            "coalesce": expression.args.get("coalesce"),
 743        }
 744
 745    new_args = []
 746    for is_string_group, group in itertools.groupby(
 747        expressions or expression.flatten(), lambda e: e.is_string
 748    ):
 749        if is_string_group:
 750            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
 751        else:
 752            new_args.extend(group)
 753
 754    if len(new_args) == 1 and new_args[0].is_string:
 755        return new_args[0]
 756
 757    if concat_type is exp.ConcatWs:
 758        new_args = [sep_expr] + new_args
 759
 760    return concat_type(expressions=new_args, **args)
 761
 762
 763def simplify_conditionals(expression):
 764    """Simplifies expressions like IF, CASE if their condition is statically known."""
 765    if isinstance(expression, exp.Case):
 766        this = expression.this
 767        for case in expression.args["ifs"]:
 768            cond = case.this
 769            if this:
 770                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
 771                cond = cond.replace(this.pop().eq(cond))
 772
 773            if always_true(cond):
 774                return case.args["true"]
 775
 776            if always_false(cond):
 777                case.pop()
 778                if not expression.args["ifs"]:
 779                    return expression.args.get("default") or exp.null()
 780    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
 781        if always_true(expression.this):
 782            return expression.args["true"]
 783        if always_false(expression.this):
 784            return expression.args.get("false") or exp.null()
 785
 786    return expression
 787
 788
 789def simplify_startswith(expression: exp.Expression) -> exp.Expression:
 790    """
 791    Reduces a prefix check to either TRUE or FALSE if both the string and the
 792    prefix are statically known.
 793
 794    Example:
 795        >>> from sqlglot import parse_one
 796        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
 797        'TRUE'
 798    """
 799    if (
 800        isinstance(expression, exp.StartsWith)
 801        and expression.this.is_string
 802        and expression.expression.is_string
 803    ):
 804        return exp.convert(expression.name.startswith(expression.expression.name))
 805
 806    return expression
 807
 808
# A (start, end) pair of dates; `_datetrunc_range` returns it as a half-open [min, max) range
DateRange = t.Tuple[datetime.date, datetime.date]
 810
 811
 812def _datetrunc_range(date: datetime.date, unit: str, dialect: Dialect) -> t.Optional[DateRange]:
 813    """
 814    Get the date range for a DATE_TRUNC equality comparison:
 815
 816    Example:
 817        _datetrunc_range(date(2021-01-01), 'year') == (date(2021-01-01), date(2022-01-01))
 818    Returns:
 819        tuple of [min, max) or None if a value can never be equal to `date` for `unit`
 820    """
 821    floor = date_floor(date, unit, dialect)
 822
 823    if date != floor:
 824        # This will always be False, except for NULL values.
 825        return None
 826
 827    return floor, floor + interval(unit)
 828
 829
 830def _datetrunc_eq_expression(left: exp.Expression, drange: DateRange) -> exp.Expression:
 831    """Get the logical expression for a date range"""
 832    return exp.and_(
 833        left >= date_literal(drange[0]),
 834        left < date_literal(drange[1]),
 835        copy=False,
 836    )
 837
 838
 839def _datetrunc_eq(
 840    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 841) -> t.Optional[exp.Expression]:
 842    drange = _datetrunc_range(date, unit, dialect)
 843    if not drange:
 844        return None
 845
 846    return _datetrunc_eq_expression(left, drange)
 847
 848
 849def _datetrunc_neq(
 850    left: exp.Expression, date: datetime.date, unit: str, dialect: Dialect
 851) -> t.Optional[exp.Expression]:
 852    drange = _datetrunc_range(date, unit, dialect)
 853    if not drange:
 854        return None
 855
 856    return exp.and_(
 857        left < date_literal(drange[0]),
 858        left >= date_literal(drange[1]),
 859        copy=False,
 860    )
 861
 862
# Rewrites for `DATE_TRUNC(u, l) <op> dt`, keyed by the comparison class.
# Each callable receives (l, dt, u, d): the truncated operand, the date it is
# compared against, the unit and the dialect, and returns a comparison on `l` itself.
DATETRUNC_BINARY_COMPARISONS: t.Dict[t.Type[exp.Expression], DateTruncBinaryTransform] = {
    # trunc(l) < dt: bound is dt when dt is aligned to the unit, else the next boundary
    exp.LT: lambda l, dt, u, d: l
    < date_literal(dt if dt == date_floor(dt, u, d) else date_floor(dt, u, d) + interval(u)),
    # trunc(l) > dt: l is at or past the boundary following floor(dt)
    exp.GT: lambda l, dt, u, d: l >= date_literal(date_floor(dt, u, d) + interval(u)),
    # trunc(l) <= dt: l is before the boundary following floor(dt)
    exp.LTE: lambda l, dt, u, d: l < date_literal(date_floor(dt, u, d) + interval(u)),
    # trunc(l) >= dt: l is at or past ceil(dt)
    exp.GTE: lambda l, dt, u, d: l >= date_literal(date_ceil(dt, u, d)),
    exp.EQ: _datetrunc_eq,
    exp.NEQ: _datetrunc_neq,
}
# Every comparison class `simplify_datetrunc` considers: the binary ones above plus IN
DATETRUNC_COMPARISONS = {exp.In, *DATETRUNC_BINARY_COMPARISONS}
# Truncation function node types recognised by the DATE_TRUNC rules
DATETRUNCS = (exp.DateTrunc, exp.TimestampTrunc)
 874
 875
 876def _is_datetrunc_predicate(left: exp.Expression, right: exp.Expression) -> bool:
 877    return isinstance(left, DATETRUNCS) and _is_date_literal(right)
 878
 879
 880@catch(ModuleNotFoundError, UnsupportedUnit)
 881def simplify_datetrunc(expression: exp.Expression, dialect: Dialect) -> exp.Expression:
 882    """Simplify expressions like `DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)`"""
 883    comparison = expression.__class__
 884
 885    if isinstance(expression, DATETRUNCS):
 886        date = extract_date(expression.this)
 887        if date and expression.unit:
 888            return date_literal(date_floor(date, expression.unit.name.lower(), dialect))
 889    elif comparison not in DATETRUNC_COMPARISONS:
 890        return expression
 891
 892    if isinstance(expression, exp.Binary):
 893        l, r = expression.left, expression.right
 894
 895        if not _is_datetrunc_predicate(l, r):
 896            return expression
 897
 898        l = t.cast(exp.DateTrunc, l)
 899        unit = l.unit.name.lower()
 900        date = extract_date(r)
 901
 902        if not date:
 903            return expression
 904
 905        return DATETRUNC_BINARY_COMPARISONS[comparison](l.this, date, unit, dialect) or expression
 906    elif isinstance(expression, exp.In):
 907        l = expression.this
 908        rs = expression.expressions
 909
 910        if rs and all(_is_datetrunc_predicate(l, r) for r in rs):
 911            l = t.cast(exp.DateTrunc, l)
 912            unit = l.unit.name.lower()
 913
 914            ranges = []
 915            for r in rs:
 916                date = extract_date(r)
 917                if not date:
 918                    return expression
 919                drange = _datetrunc_range(date, unit, dialect)
 920                if drange:
 921                    ranges.append(drange)
 922
 923            if not ranges:
 924                return expression
 925
 926            ranges = merge_ranges(ranges)
 927
 928            return exp.or_(*[_datetrunc_eq_expression(l, drange) for drange in ranges], copy=False)
 929
 930    return expression
 931
 932
 933def sort_comparison(expression: exp.Expression) -> exp.Expression:
 934    if expression.__class__ in COMPLEMENT_COMPARISONS:
 935        l, r = expression.this, expression.expression
 936        l_column = isinstance(l, exp.Column)
 937        r_column = isinstance(r, exp.Column)
 938        l_const = _is_constant(l)
 939        r_const = _is_constant(r)
 940
 941        if (l_column and not r_column) or (r_const and not l_const):
 942            return expression
 943        if (r_column and not l_column) or (l_const and not r_const) or (gen(l) > gen(r)):
 944            return INVERSE_COMPARISONS.get(expression.__class__, expression.__class__)(
 945                this=r, expression=l
 946            )
 947    return expression
 948
 949
# CROSS joins result in an empty table if the right table is empty.
# So we can only simplify certain types of joins to CROSS.
# Or in other words, LEFT JOIN x ON TRUE != CROSS JOIN x
#
# Each entry is a (side, kind) pair; remove_where_true compares it against
# (join.side, join.kind) and only rewrites joins whose pair is listed here.
# NOTE(review): the ("RIGHT", ...) entries look like they have the mirror-image
# problem when the *left* input is empty — confirm they are intended.
JOINS = {
    ("", ""),
    ("", "INNER"),
    ("RIGHT", ""),
    ("RIGHT", "OUTER"),
}
 959
 960
 961def remove_where_true(expression):
 962    for where in expression.find_all(exp.Where):
 963        if always_true(where.this):
 964            where.pop()
 965    for join in expression.find_all(exp.Join):
 966        if (
 967            always_true(join.args.get("on"))
 968            and not join.args.get("using")
 969            and not join.args.get("method")
 970            and (join.side, join.kind) in JOINS
 971        ):
 972            join.args["on"].pop()
 973            join.set("side", None)
 974            join.set("kind", "CROSS")
 975
 976
 977def always_true(expression):
 978    return (isinstance(expression, exp.Boolean) and expression.this) or isinstance(
 979        expression, exp.Literal
 980    )
 981
 982
 983def always_false(expression):
 984    return is_false(expression) or is_null(expression)
 985
 986
 987def is_complement(a, b):
 988    return isinstance(b, exp.Not) and b.this == a
 989
 990
 991def is_false(a: exp.Expression) -> bool:
 992    return type(a) is exp.Boolean and not a.this
 993
 994
def is_null(a: exp.Expression) -> bool:
    """Return True when `a` is exactly a Null node (subclasses do not count)."""
    return type(a) is exp.Null
 997
 998
 999def eval_boolean(expression, a, b):
1000    if isinstance(expression, (exp.EQ, exp.Is)):
1001        return boolean_literal(a == b)
1002    if isinstance(expression, exp.NEQ):
1003        return boolean_literal(a != b)
1004    if isinstance(expression, exp.GT):
1005        return boolean_literal(a > b)
1006    if isinstance(expression, exp.GTE):
1007        return boolean_literal(a >= b)
1008    if isinstance(expression, exp.LT):
1009        return boolean_literal(a < b)
1010    if isinstance(expression, exp.LTE):
1011        return boolean_literal(a <= b)
1012    return None
1013
1014
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`, or return None if it cannot be converted.

    Args:
        value: a `datetime.datetime` (its date part is returned), a
            `datetime.date` (returned as is), or an ISO-format string.

    Returns:
        The date, or None when `value` is a string that is not valid ISO
        format or is of a type `fromisoformat` does not accept.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value.date()
    if isinstance(value, datetime.date):
        return value
    try:
        return datetime.datetime.fromisoformat(value).date()
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not an ISO date string
        return None
1024
1025
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`, or return None if it cannot be converted.

    Args:
        value: a `datetime.datetime` (returned as is), a `datetime.date`
            (converted to midnight of that day), or an ISO-format string.

    Returns:
        The datetime, or None when `value` is a string that is not valid ISO
        format or is of a type `fromisoformat` does not accept.
    """
    # datetime is a subclass of date, so it has to be checked first
    if isinstance(value, datetime.datetime):
        return value
    if isinstance(value, datetime.date):
        return datetime.datetime(year=value.year, month=value.month, day=value.day)
    try:
        return datetime.datetime.fromisoformat(value)
    except (TypeError, ValueError):
        # TypeError: value is not a string; ValueError: not an ISO datetime string
        return None
1035
1036
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date/datetime matching the target type `to`.

    Args:
        value: the value to convert; any falsy value yields None.
        to: the target data type.

    Returns:
        A `datetime.date` when `to` is DATE, a `datetime.datetime` when `to` is
        one of `exp.DataType.TEMPORAL_TYPES`, and None for any other type or
        when the conversion fails.
    """
    if not value:
        return None
    # DATE is tested before the temporal types so that it maps to a plain date
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
1045
1046
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Extract the Python date/datetime denoted by a cast of a literal.

    Args:
        cast: a Cast node, or a TsOrDsToDate node without a "format" argument
            (treated as a cast to DATE). Its operand must be a Literal or,
            recursively, another such cast.

    Returns:
        The converted value from `cast_value`, or None when `cast` has any
        other shape or the value cannot be converted.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested casts: resolve the inner one first, then convert to the outer type
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
1062
1063
def _is_date_literal(expression: exp.Expression) -> bool:
    """Return True when `extract_date` can resolve `expression` to a date value."""
    extracted = extract_date(expression)
    return extracted is not None
1066
1067
def extract_interval(expression):
    """Convert an interval expression into the delta object built by `interval`.

    The count is read from `expression.name` and the unit from the lowercased
    "unit" text. Returns None when the count is not an integer, the unit is
    unsupported, or the module needed by `interval` is not installed.
    """
    try:
        count = int(expression.name)
        unit_name = expression.text("unit").lower()
        return interval(unit_name, count)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
1075
1076
def date_literal(date):
    """Build a cast of the string form of `date` to DATETIME or DATE.

    DATETIME is used when `date` is a `datetime.datetime`; DATE otherwise.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
1086
1087
def interval(unit: str, n: int = 1):
    """Return a `relativedelta` spanning `n` units of `unit`.

    Args:
        unit: one of "year", "quarter", "month", "week", "day", "hour",
            "minute" or "second".
        n: how many units the delta should span.

    Raises:
        ModuleNotFoundError: if `dateutil` is not installed.
        UnsupportedUnit: if `unit` is not one of the names above.
    """
    from dateutil.relativedelta import relativedelta

    # unit name -> (relativedelta keyword, number of that keyword per unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    span = spans.get(unit)
    if span is None:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, per_unit = span
    return relativedelta(**{keyword: per_unit * n})
1109
1110
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of its `unit`.

    Args:
        d: the date to floor. A `datetime.datetime` is accepted as well; its
            time of day is reset to midnight, since every supported unit is
            at least a day long.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only `dialect.WEEK_OFFSET` is read, for the "week" unit.
            Presumably 0 means weeks start on Monday and -1 on Sunday —
            confirm against the Dialect documentation.

    Returns:
        A value of the same type as `d`, at the start of the unit.

    Raises:
        UnsupportedUnit: if `unit` is not one of the names above.
    """
    if isinstance(d, datetime.datetime):
        # Truncating to a day or coarser always discards the time of day
        d = d.replace(hour=0, minute=0, second=0, microsecond=0)

    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10
        first_month = 3 * ((d.month - 1) // 3) + 1
        return d.replace(month=first_month, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday .. 6 for Sunday. The modulo keeps the step within
        # 0..6 days so that, with a non-zero offset, a date that already is the first
        # day of its week maps to itself and the result is never in the future.
        days_into_week = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_into_week)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
1132
1133
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to a `unit` boundary.

    Returns `d` itself when it already equals its floor; otherwise the floor
    plus one `unit`.
    """
    lower_boundary = date_floor(d, unit, dialect)
    return d if lower_boundary == d else lower_boundary + interval(unit)
1141
1142
def boolean_literal(condition):
    """Return the TRUE node when `condition` is truthy, otherwise the FALSE node."""
    if condition:
        return exp.true()
    return exp.false()
1145
1146
def _flat_simplify(expression, simplifier, root=True):
    """Simplify the operands of a chained binary expression pairwise.

    Args:
        expression: the node whose flattened operands are simplified.
        simplifier: callable `(expression, a, b)` returning a replacement for
            the pair `a`, `b`, or a falsy value / `expression` itself when the
            pair cannot be combined.
        root: whether `expression` is the root being simplified. Non-root
            nodes are only processed when `expression.same_parent` is falsy.

    Returns:
        A new left-deep chain of `expression.__class__` nodes when at least
        one pair was combined, otherwise `expression` unchanged.
    """
    if root or not expression.same_parent:
        operands = []
        queue = deque(expression.flatten(unnest=False))
        size = len(queue)

        while queue:
            a = queue.popleft()

            for b in queue:
                result = simplifier(expression, a, b)

                if result and result is not expression:
                    # a and b collapse into result: drop b and put result at the
                    # front so it is paired against the remaining operands next
                    queue.remove(b)
                    queue.appendleft(result)
                    break
            else:
                # a could not be combined with any remaining operand
                operands.append(a)

        # Only rebuild when at least one pair was actually combined
        if len(operands) < size:
            return functools.reduce(
                lambda a, b: expression.__class__(this=a, expression=b), operands
            )
    return expression
1171
1172
def gen(expression: t.Any) -> str:
    """Produce a cheap pseudo-SQL string for `expression`, suitable for sorting and deduping.

    The optimizer needs sortable, unique strings for expressions; the real
    generator is expensive, so a fresh `Gen` instance renders a bare-minimum
    form instead.
    """
    generator = Gen()
    return generator.gen(expression)
1180
1181
class Gen:
    """Minimal, non-recursive pseudo-SQL generator.

    Rendering is driven by an explicit stack. Items are popped one at a time:
    expression nodes are expanded by a `<key>_sql` handler (or a generic
    fallback), lists are expanded into comma-separated items, and anything
    else is appended to the output as a string. Because the stack is LIFO,
    every handler pushes its pieces in reverse of the order they should appear.
    """

    def __init__(self):
        # Pending nodes / strings / lists, and the output fragments emitted so far
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Render `expression` and return the resulting string."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: the upper-cased key followed by any set args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                if pushed:
                    # Drop the separator left in front of the first item. Nothing is
                    # popped when every item was None, otherwise an unrelated entry
                    # (e.g. a closing parenthesis) would be removed from the stack.
                    self.stack.pop()
            else:
                if node is not None:
                    self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        """Render an anonymous function call; quoted names keep their case."""
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Dot-separated parts; the final pop removes the separator before the first part
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Args after the first declared one are rendered after the type name
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # Args after the first two declared ones come last, preceded by the alias
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Args after the first four declared ones come last, preceded by the alias
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `<sql name>(<all arg values, comma separated>)`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Queue the set args of `node` as `:name,value` pairs.

        Args:
            node: the expression whose args are rendered.
            arg_index: number of leading entries of `node.arg_types` to skip.

        Returns:
            True when at least one arg was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
# Meta key marking an expression as final, i.e. one that should not be simplified
FINAL = 'final'
class UnsupportedUnit(builtins.Exception):
class UnsupportedUnit(Exception):
    """Raised by the date helpers when a unit name is not one they handle."""

    pass

Common base class for all non-exit exceptions.

Inherited Members
builtins.Exception
Exception
builtins.BaseException
with_traceback
args
def simplify( expression: sqlglot.expressions.Expression, constant_propagation: bool = False, dialect: Union[str, sqlglot.dialects.dialect.Dialect, Type[sqlglot.dialects.dialect.Dialect], NoneType] = None):
def simplify(
    expression: exp.Expression, constant_propagation: bool = False, dialect: DialectType = None
):
    """
    Rewrite sqlglot AST to simplify expressions.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("TRUE AND TRUE")
        >>> simplify(expression).sql()
        'TRUE'

    Args:
        expression (sqlglot.Expression): expression to simplify
        constant_propagation: whether the constant propagation rule should be used
        dialect: dialect resolved through `Dialect.get_or_raise` and handed to
            the date truncation simplification

    Returns:
        sqlglot.Expression: simplified expression
    """

    dialect = Dialect.get_or_raise(dialect)

    def _simplify(expression, root=True):
        # Nodes flagged as final are returned untouched
        if expression.meta.get(FINAL):
            return expression

        # group by expressions cannot be simplified, for example
        # select x + 1 + 1 FROM y GROUP BY x + 1 + 1
        # the projection must exactly match the group by key
        group = expression.args.get("group")

        if group and hasattr(expression, "selects"):
            groups = set(group.expressions)
            group.meta[FINAL] = True

            # Freeze every projection that contains one of the group by keys
            for e in expression.selects:
                for node in e.walk():
                    if node in groups:
                        e.meta[FINAL] = True
                        break

            # Same for the HAVING clause
            having = expression.args.get("having")
            if having:
                for node in having.walk():
                    if node in groups:
                        having.meta[FINAL] = True
                        break

        # Pre-order transformations
        node = expression
        node = rewrite_between(node)
        node = uniq_sort(node, root)
        node = absorb_and_eliminate(node, root)
        node = simplify_concat(node)
        node = simplify_conditionals(node)

        if constant_propagation:
            node = propagate_constants(node, root)

        # Recurse into the children; they are never treated as the root
        exp.replace_children(node, lambda e: _simplify(e, False))

        # Post-order transformations
        node = simplify_not(node)
        node = flatten(node)
        node = simplify_connectors(node, root)
        node = remove_complements(node, root)
        node = simplify_coalesce(node)
        # The remaining rules see the original parent on the (possibly new) node
        node.parent = expression.parent
        node = simplify_literals(node, root)
        node = simplify_equality(node)
        node = simplify_parens(node)
        node = simplify_datetrunc(node, dialect)
        node = sort_comparison(node)
        node = simplify_startswith(node)

        if root:
            expression.replace(node)
        return node

    # Re-run the rules via while_changing, then drop always-true WHERE / ON conditions
    expression = while_changing(expression, _simplify)
    remove_where_true(expression)
    return expression

Rewrite sqlglot AST to simplify expressions.

Example:
>>> import sqlglot
>>> expression = sqlglot.parse_one("TRUE AND TRUE")
>>> simplify(expression).sql()
'TRUE'
Arguments:
  • expression (sqlglot.Expression): expression to simplify
  • constant_propagation: whether the constant propagation rule should be used
Returns:

sqlglot.Expression: simplified expression

def catch(*exceptions):
def catch(*exceptions):
    """Decorator that ignores a simplification function if any of `exceptions` are raised

    Args:
        *exceptions: exception classes to swallow.

    Returns:
        A decorator. The decorated function behaves like the original, except
        that when one of `exceptions` is raised its first argument
        (`expression`) is returned unchanged instead.
    """

    def decorator(func):
        # Keep the wrapped function's name, docstring and signature metadata so that
        # introspection and generated documentation show it rather than `wrapped`.
        @functools.wraps(func)
        def wrapped(expression, *args, **kwargs):
            try:
                return func(expression, *args, **kwargs)
            except exceptions:
                return expression

        return wrapped

    return decorator

Decorator that ignores a simplification function if any of exceptions are raised

def rewrite_between( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def rewrite_between(expression: exp.Expression) -> exp.Expression:
    """Turn `x BETWEEN y AND z` into `x >= y AND x <= z`.

    Comparison simplification only understands lt/lte/gt/gte, so BETWEEN is
    expanded first. When the BETWEEN sits directly under a NOT, the resulting
    conjunction is wrapped in parentheses. Other expressions are returned
    unchanged.
    """
    if not isinstance(expression, exp.Between):
        return expression

    under_not = isinstance(expression.parent, exp.Not)

    lower_check = exp.GTE(this=expression.this.copy(), expression=expression.args["low"])
    upper_check = exp.LTE(this=expression.this.copy(), expression=expression.args["high"])
    conjunction = exp.and_(lower_check, upper_check, copy=False)

    if under_not:
        return exp.paren(conjunction, copy=False)
    return conjunction

Rewrite x between y and z to x >= y AND x <= z.

This is done because comparison simplification is only done on lt/lte/gt/gte.

def simplify_not(expression):
def simplify_not(expression):
    """
    Simplify a negation.

    De Morgan's laws (applied to a parenthesized operand):
        NOT (x OR y) -> NOT x AND NOT y
        NOT (x AND y) -> NOT x OR NOT y

    Also handled: NOT NULL -> NULL, negated comparisons listed in
    COMPLEMENT_COMPARISONS, NOT <always true> -> FALSE, NOT FALSE -> TRUE and
    NOT NOT x -> x. Anything else is returned unchanged.
    """
    if not isinstance(expression, exp.Not):
        return expression

    operand = expression.this

    if is_null(operand):
        return exp.null()

    operand_cls = operand.__class__
    if operand_cls in COMPLEMENT_COMPARISONS:
        complement_cls = COMPLEMENT_COMPARISONS[operand_cls]
        return complement_cls(this=operand.this, expression=operand.expression)

    if isinstance(operand, exp.Paren):
        inner = operand.unnest()

        # Each connector is rewritten with its dual: AND -> OR, OR -> AND
        for connector, dual in ((exp.And, exp.or_), (exp.Or, exp.and_)):
            if isinstance(inner, connector):
                negated_left = exp.not_(inner.left, copy=False)
                negated_right = exp.not_(inner.right, copy=False)
                return exp.paren(dual(negated_left, negated_right, copy=False))

        if is_null(inner):
            return exp.null()

    if always_true(operand):
        return exp.false()
    if is_false(operand):
        return exp.true()
    if isinstance(operand, exp.Not):
        # double negation
        # NOT NOT x -> x
        return operand.this

    return expression

De Morgan's laws: NOT (x OR y) -> NOT x AND NOT y; NOT (x AND y) -> NOT x OR NOT y

def flatten(expression):
def flatten(expression):
    """
    A AND (B AND C) -> A AND B AND C
    A OR (B OR C) -> A OR B OR C

    Only connectors are touched. Each operand that, once unnested, is a
    connector of the same class as `expression` is replaced in place by that
    unnested connector. `expression` itself is returned.
    """
    if isinstance(expression, exp.Connector):
        for node in expression.args.values():
            child = node.unnest()
            # Unwrap only when the unnested child is the same kind of connector
            if isinstance(child, expression.__class__):
                node.replace(child)
    return expression

A AND (B AND C) -> A AND B AND C; A OR (B OR C) -> A OR B OR C

def simplify_connectors(expression, root=True):
def simplify_connectors(expression, root=True):
    """Simplify AND / OR chains by combining their operands pairwise.

    Non-connector expressions are returned unchanged. For connectors the work
    is delegated to `_flat_simplify`, using the pair rules below.
    """

    def _simplify_connectors(expression, left, right):
        # Identical operands collapse into one
        if left == right:
            return left
        if isinstance(expression, exp.And):
            # FALSE wins over everything, then NULL
            if is_false(left) or is_false(right):
                return exp.false()
            if is_null(left) or is_null(right):
                return exp.null()
            if always_true(left) and always_true(right):
                return exp.true()
            # A true operand is redundant in a conjunction
            if always_true(left):
                return right
            if always_true(right):
                return left
            return _simplify_comparison(expression, left, right)
        elif isinstance(expression, exp.Or):
            # TRUE wins over everything
            if always_true(left) or always_true(right):
                return exp.true()
            if is_false(left) and is_false(right):
                return exp.false()
            # Any mix of NULL with NULL or FALSE yields NULL
            if (
                (is_null(left) and is_null(right))
                or (is_null(left) and is_false(right))
                or (is_false(left) and is_null(right))
            ):
                return exp.null()
            # A false operand is redundant in a disjunction
            if is_false(left):
                return right
            if is_false(right):
                return left
            return _simplify_comparison(expression, left, right, or_=True)
        # Any other connector type falls through and returns None (no simplification)

    if isinstance(expression, exp.Connector):
        return _flat_simplify(expression, _simplify_connectors, root)
    return expression
LT_LTE = (<class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.LTE'>)
GT_GTE = (<class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.GTE'>)
NONDETERMINISTIC = (<class 'sqlglot.expressions.Rand'>, <class 'sqlglot.expressions.Randn'>)
def remove_complements(expression, root=True):
def remove_complements(expression, root=True):
    """
    Collapse a connector that contains an operand together with its negation.

    A AND NOT A -> FALSE
    A OR NOT A -> TRUE

    Non-root connectors are only examined when `expression.same_parent` is
    falsy. Everything else is returned unchanged.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    collapsed = exp.false() if isinstance(expression, exp.And) else exp.true()

    operand_pairs = itertools.permutations(expression.flatten(), 2)
    if any(is_complement(a, b) for a, b in operand_pairs):
        return collapsed

    return expression

Removing complements.

A AND NOT A -> FALSE; A OR NOT A -> TRUE

def uniq_sort(expression, root=True):
def uniq_sort(expression, root=True):
    """
    Deduplicate and sort the operands of a connector.

    C AND A AND B AND B -> A AND B AND C

    Operands are keyed by their `gen` string; that string decides both
    equality and ordering. Non-root connectors are only processed when
    `expression.same_parent` is falsy.
    """
    if not isinstance(expression, exp.Connector):
        return expression
    if not root and expression.same_parent:
        return expression

    combine = exp.and_ if isinstance(expression, exp.And) else exp.or_
    operands = tuple(expression.flatten())
    by_sql = {gen(operand): operand for operand in operands}
    keys = tuple(by_sql)

    # A AND C AND B -> A AND B AND C
    out_of_order = any(later < earlier for earlier, later in zip(keys, keys[1:]))
    if out_of_order:
        ordered = (operand for _, operand in sorted(by_sql.items()))
        return combine(*ordered, copy=False)

    # already sorted, but duplicates may still have to go
    if len(by_sql) < len(operands):
        return combine(*by_sql.values(), copy=False)

    return expression

Uniq and sort a connector.

C AND A AND B AND B -> A AND B AND C

def absorb_and_eliminate(expression, root=True):
def absorb_and_eliminate(expression, root=True):
    """
    absorption:
        A AND (A OR B) -> A
        A OR (A AND B) -> A
        A AND (NOT A OR B) -> A AND B
        A OR (NOT A AND B) -> A OR B
    elimination:
        (A AND B) OR (A AND NOT B) -> A
        (A OR B) AND (A OR NOT B) -> A

    The operands are rewritten in place and `expression` itself is returned.
    Non-root connectors are only processed when `expression.same_parent` is
    falsy.
    """
    if isinstance(expression, exp.Connector) and (root or not expression.same_parent):
        # The nested connector of interest is the dual of the outer one
        kind = exp.Or if isinstance(expression, exp.And) else exp.And

        for a, b in itertools.permutations(expression.flatten(), 2):
            if isinstance(a, kind):
                aa, ab = a.unnest_operands()

                # absorb
                if is_complement(b, aa):
                    # aa is NOT b: swap it for the neutral element of the nested connector
                    aa.replace(exp.true() if kind == exp.And else exp.false())
                elif is_complement(b, ab):
                    ab.replace(exp.true() if kind == exp.And else exp.false())
                elif (set(b.flatten()) if isinstance(b, kind) else {b}) < set(a.flatten()):
                    # b's operands are a strict subset of a's: a is redundant, so it is
                    # swapped for the neutral element of the outer connector
                    a.replace(exp.false() if kind == exp.And else exp.true())
                elif isinstance(b, kind):
                    # eliminate
                    rhs = b.unnest_operands()
                    ba, bb = rhs

                    if aa in rhs and (is_complement(ab, ba) or is_complement(ab, bb)):
                        a.replace(aa)
                        b.replace(aa)
                    elif ab in rhs and (is_complement(aa, ba) or is_complement(aa, bb)):
                        a.replace(ab)
                        b.replace(ab)

    return expression

Absorption: A AND (A OR B) -> A; A OR (A AND B) -> A; A AND (NOT A OR B) -> A AND B; A OR (NOT A AND B) -> A OR B. Elimination: (A AND B) OR (A AND NOT B) -> A; (A OR B) AND (A OR NOT B) -> A.

def propagate_constants(expression, root=True):
def propagate_constants(expression, root=True):
    """
    Propagate constants for conjunctions in DNF:

    SELECT * FROM t WHERE a = b AND b = 5 becomes
    SELECT * FROM t WHERE a = 5 AND b = 5

    Reference: https://www.sqlite.org/optoverview.html

    Matching columns are replaced in place with copies of the literal, and the
    (possibly mutated) `expression` is returned.
    """

    if (
        isinstance(expression, exp.And)
        and (root or not expression.same_parent)
        and sqlglot.optimizer.normalize.normalized(expression, dnf=True)
    ):
        # Maps a column to (id of the column node inside its `column = literal`
        # predicate, the literal). A later predicate on an equal column overwrites an
        # earlier one.
        constant_mapping = {}
        # IF nodes are passed as the prune predicate, presumably so their subtrees are
        # not searched for equalities -- confirm against walk_in_scope.
        for expr in walk_in_scope(expression, prune=lambda node: isinstance(node, exp.If)):
            if isinstance(expr, exp.EQ):
                l, r = expr.left, expr.right

                # TODO: create a helper that can be used to detect nested literal expressions such
                # as CAST(123456 AS BIGINT), since we usually want to treat those as literals too
                if isinstance(l, exp.Column) and isinstance(r, exp.Literal):
                    constant_mapping[l] = (id(l), r)

        if constant_mapping:
            for column in find_all_in_scope(expression, exp.Column):
                parent = column.parent
                column_id, constant = constant_mapping.get(column) or (None, None)
                if (
                    column_id is not None
                    # Skip the very node that defines the constant (`b` in `b = 5`)
                    and id(column) != column_id
                    # Leave `column IS NULL` checks untouched
                    and not (isinstance(parent, exp.Is) and isinstance(parent.expression, exp.Null))
                ):
                    column.replace(constant.copy())

    return expression

Propagate constants for conjunctions in DNF:

SELECT * FROM t WHERE a = b AND b = 5
becomes
SELECT * FROM t WHERE a = 5 AND b = 5

Reference: https://www.sqlite.org/optoverview.html

def simplify_equality(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression
Use the subtraction and addition properties of equality to simplify expressions:

x + 1 = 3 becomes x = 2

There are two binary operations in the above expression: + and =. Here's how we reference all the operands in the code below:

  l     r
x + 1 = 3
a   b
def simplify_literals(expression, root=True):
def simplify_literals(expression, root=True):
    """Fold literal arithmetic, negated number literals and inverse date operations.

    Non-connector binaries are delegated to `_flat_simplify`; a `Neg` wrapping a
    number literal becomes a single signed number literal; nodes whose exact type is
    in `INVERSE_DATE_OPS` are handed to `_simplify_binary`. Anything else is returned
    unchanged.
    """
    is_plain_binary = isinstance(expression, exp.Binary) and not isinstance(
        expression, exp.Connector
    )
    if is_plain_binary:
        return _flat_simplify(expression, _simplify_binary, root)

    if isinstance(expression, exp.Neg):
        operand = expression.this
        if operand.is_number:
            digits = operand.name
            # Negating an already negative literal strips its sign instead of adding one
            signed = digits[1:] if digits[0] == "-" else f"-{digits}"
            return exp.Literal.number(signed)

    if type(expression) in INVERSE_DATE_OPS:
        folded = _simplify_binary(expression, expression.this, expression.interval())
        return folded or expression

    return expression
def simplify_parens(expression):
def simplify_parens(expression):
    """Return the wrapped expression when a Paren node is judged redundant.

    Non-Paren nodes, and Paren nodes that do not match any of the conditions below,
    are returned unchanged.
    """
    if not isinstance(expression, exp.Paren):
        return expression

    this = expression.this
    parent = expression.parent
    parent_is_predicate = isinstance(parent, exp.Predicate)

    # Parentheses around a SELECT are always kept. Otherwise they are dropped as soon as
    # any one of the following alternatives holds.
    if not isinstance(this, exp.Select) and (
        # the parent is neither a Condition nor a Binary
        not isinstance(parent, (exp.Condition, exp.Binary))
        # the parent is itself a Paren (nested parentheses)
        or isinstance(parent, exp.Paren)
        # the content is not a Binary, unless it is a NOT / IS under a predicate parent
        or (
            not isinstance(this, exp.Binary)
            and not (isinstance(this, (exp.Not, exp.Is)) and parent_is_predicate)
        )
        # the content is a predicate and the parent is not
        or (isinstance(this, exp.Predicate) and not parent_is_predicate)
        # (a + b) + c and (a * b) * c
        or (isinstance(this, exp.Add) and isinstance(parent, exp.Add))
        or (isinstance(this, exp.Mul) and isinstance(parent, exp.Mul))
        # (a * b) + c and (a * b) - c
        or (isinstance(this, exp.Mul) and isinstance(parent, (exp.Add, exp.Sub)))
    ):
        return this
    return expression
NONNULL_CONSTANTS = (<class 'sqlglot.expressions.Literal'>, <class 'sqlglot.expressions.Boolean'>)
def simplify_coalesce(expression):
def simplify_coalesce(expression):
    """Simplify COALESCE calls and comparisons that have a COALESCE on one side.

    A COALESCE with no fallback arguments, or whose first argument is a non-null
    constant, is replaced by its first argument. A comparison between a COALESCE and a
    constant is expanded into an explicit IS NULL / IS NOT NULL disjunction, truncating
    the COALESCE arguments at the first constant one (the COALESCE node is mutated).
    """
    # COALESCE(x) -> x
    if (
        isinstance(expression, exp.Coalesce)
        and (not expression.expressions or _is_nonnull_constant(expression.this))
        # COALESCE is also used as a Spark partitioning hint
        and not isinstance(expression.parent, exp.Hint)
    ):
        return expression.this

    if not isinstance(expression, COMPARISONS):
        return expression

    if isinstance(expression.left, exp.Coalesce):
        coalesce = expression.left
        other = expression.right
    elif isinstance(expression.right, exp.Coalesce):
        coalesce = expression.right
        other = expression.left
    else:
        return expression

    # This transformation is valid for non-constants,
    # but it really only does anything if they are both constants.
    if not _is_constant(other):
        return expression

    # Find the first constant arg
    for arg_index, arg in enumerate(coalesce.expressions):
        if _is_constant(arg):
            break
    else:
        # No constant fallback argument: nothing to expand
        return expression

    # Keep only the fallback arguments that precede the first constant one; `arg` still
    # refers to that constant and is used for the IS NULL branch below.
    coalesce.set("expressions", coalesce.expressions[:arg_index])

    # Remove the COALESCE function. This is an optimization, skipping a simplify iteration,
    # since we already remove COALESCE at the top of this function.
    coalesce = coalesce if coalesce.expressions else coalesce.this

    # This expression is more complex than when we started, but it will get simplified further
    return exp.paren(
        exp.or_(
            # (<coalesce> IS NOT NULL AND <original comparison>)
            exp.and_(
                coalesce.is_(exp.null()).not_(copy=False),
                expression.copy(),
                copy=False,
            ),
            # (<coalesce> IS NULL AND <constant arg> <op> <other>)
            exp.and_(
                coalesce.is_(exp.null()),
                type(expression)(this=arg.copy(), expression=other.copy()),
                copy=False,
            ),
            copy=False,
        )
    )
CONCATS = (<class 'sqlglot.expressions.Concat'>, <class 'sqlglot.expressions.DPipe'>)
def simplify_concat(expression):
def simplify_concat(expression):
    """Reduces all groups that contain string literals by concatenating them.

    Returns a single string literal when everything folds into one, otherwise a new
    `Concat` / `ConcatWs` node built from the folded arguments. Nodes that are not in
    `CONCATS`, and `ConcatWs` calls whose separator is not a string literal, are
    returned unchanged.
    """
    if not isinstance(expression, CONCATS) or (
        # We can't reduce a CONCAT_WS call if we don't statically know the separator
        isinstance(expression, exp.ConcatWs) and not expression.expressions[0].is_string
    ):
        return expression

    if isinstance(expression, exp.ConcatWs):
        # The first argument of CONCAT_WS is the separator
        sep_expr, *expressions = expression.expressions
        sep = sep_expr.name
        concat_type = exp.ConcatWs
        args = {}
    else:
        expressions = expression.expressions
        sep = ""
        concat_type = exp.Concat
        # Carry over the flags of the original node
        args = {
            "safe": expression.args.get("safe"),
            "coalesce": expression.args.get("coalesce"),
        }

    new_args = []
    # Consecutive string literals form one group and are merged into a single literal;
    # when there is no `expressions` list the flattened operands are used instead.
    for is_string_group, group in itertools.groupby(
        expressions or expression.flatten(), lambda e: e.is_string
    ):
        if is_string_group:
            new_args.append(exp.Literal.string(sep.join(string.name for string in group)))
        else:
            new_args.extend(group)

    if len(new_args) == 1 and new_args[0].is_string:
        return new_args[0]

    if concat_type is exp.ConcatWs:
        # Put the separator back in front of the remaining arguments
        new_args = [sep_expr] + new_args

    return concat_type(expressions=new_args, **args)

Reduces all groups that contain string literals by concatenating them.

def simplify_conditionals(expression):
def simplify_conditionals(expression):
    """Simplifies expressions like IF, CASE if their condition is statically known.

    For a CASE, the branches are inspected in order: the result of the first branch
    whose condition is always true is returned, branches whose condition is always
    false are removed, and if no branch is left the default (or NULL) is returned.
    A standalone IF (one that is not a CASE branch) is reduced to its true or false
    result in the same way. Anything else is returned unchanged.
    """
    if isinstance(expression, exp.Case):
        this = expression.this
        # Iterate over a snapshot: branches may be popped from the live "ifs" list below,
        # and removing items from a list while iterating over it can skip branches.
        for case in list(expression.args["ifs"]):
            cond = case.this
            if this:
                # Convert CASE x WHEN matching_value ... to CASE WHEN x = matching_value ...
                cond = cond.replace(this.pop().eq(cond))

            if always_true(cond):
                return case.args["true"]

            if always_false(cond):
                case.pop()
                if not expression.args["ifs"]:
                    return expression.args.get("default") or exp.null()
    elif isinstance(expression, exp.If) and not isinstance(expression.parent, exp.Case):
        if always_true(expression.this):
            return expression.args["true"]
        if always_false(expression.this):
            return expression.args.get("false") or exp.null()

    return expression

Simplifies expressions like IF, CASE if their condition is statically known.

def simplify_startswith( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def simplify_startswith(expression: exp.Expression) -> exp.Expression:
    """
    Fold a STARTSWITH call into a boolean when both of its arguments are string
    literals; any other expression is handed back untouched.

    Example:
        >>> from sqlglot import parse_one
        >>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
        'TRUE'
    """
    if not isinstance(expression, exp.StartsWith):
        return expression

    if not (expression.this.is_string and expression.expression.is_string):
        return expression

    text = expression.name
    prefix = expression.expression.name
    return exp.convert(text.startswith(prefix))

Reduces a prefix check to either TRUE or FALSE if both the string and the prefix are statically known.

Example:
>>> from sqlglot import parse_one
>>> simplify_startswith(parse_one("STARTSWITH('foo', 'f')")).sql()
'TRUE'
DateRange = typing.Tuple[datetime.date, datetime.date]
DATETRUNC_BINARY_COMPARISONS: Dict[Type[sqlglot.expressions.Expression], Callable[[sqlglot.expressions.Expression, datetime.date, str, sqlglot.dialects.dialect.Dialect], Optional[sqlglot.expressions.Expression]]] = {<class 'sqlglot.expressions.LT'>: <function <lambda>>, <class 'sqlglot.expressions.GT'>: <function <lambda>>, <class 'sqlglot.expressions.LTE'>: <function <lambda>>, <class 'sqlglot.expressions.GTE'>: <function <lambda>>, <class 'sqlglot.expressions.EQ'>: <function _datetrunc_eq>, <class 'sqlglot.expressions.NEQ'>: <function _datetrunc_neq>}
DATETRUNC_COMPARISONS = {<class 'sqlglot.expressions.LTE'>, <class 'sqlglot.expressions.GT'>, <class 'sqlglot.expressions.EQ'>, <class 'sqlglot.expressions.LT'>, <class 'sqlglot.expressions.NEQ'>, <class 'sqlglot.expressions.GTE'>, <class 'sqlglot.expressions.In'>}
def simplify_datetrunc(expression, *args, **kwargs):
119        def wrapped(expression, *args, **kwargs):
120            try:
121                return func(expression, *args, **kwargs)
122            except exceptions:
123                return expression

Simplify expressions like DATE_TRUNC('year', x) >= CAST('2021-01-01' AS DATE)

def sort_comparison( expression: sqlglot.expressions.Expression) -> sqlglot.expressions.Expression:
def sort_comparison(expression: exp.Expression) -> exp.Expression:
    """Put the operands of a comparison into a canonical order.

    A column is kept on the left of a non-column and a constant on the right of a
    non-constant; otherwise the operands are ordered by their `gen` string. When the
    operands are swapped, the comparison class is mapped through
    `INVERSE_COMPARISONS` (falling back to the same class).
    """
    node_type = expression.__class__
    if node_type not in COMPLEMENT_COMPARISONS:
        return expression

    lhs, rhs = expression.this, expression.expression
    lhs_is_column = isinstance(lhs, exp.Column)
    rhs_is_column = isinstance(rhs, exp.Column)
    lhs_is_const = _is_constant(lhs)
    rhs_is_const = _is_constant(rhs)

    # Already canonical
    if (lhs_is_column and not rhs_is_column) or (rhs_is_const and not lhs_is_const):
        return expression

    needs_swap = (
        (rhs_is_column and not lhs_is_column)
        or (lhs_is_const and not rhs_is_const)
        or (gen(lhs) > gen(rhs))
    )
    if not needs_swap:
        return expression

    swapped_type = INVERSE_COMPARISONS.get(node_type, node_type)
    return swapped_type(this=rhs, expression=lhs)
JOINS = {('', ''), ('RIGHT', ''), ('RIGHT', 'OUTER'), ('', 'INNER')}
def remove_where_true(expression):
def remove_where_true(expression):
    """Drop always-true WHERE clauses and turn always-true joins into CROSS joins.

    The tree is modified in place; nothing is returned.
    """
    for where in expression.find_all(exp.Where):
        if always_true(where.this):
            where.pop()
    for join in expression.find_all(exp.Join):
        if (
            always_true(join.args.get("on"))
            # Joins with a USING clause or a join method are left alone
            and not join.args.get("using")
            and not join.args.get("method")
            # Only the (side, kind) combinations listed in JOINS are rewritten
            and (join.side, join.kind) in JOINS
        ):
            join.args["on"].pop()
            join.set("side", None)
            join.set("kind", "CROSS")
def always_true(expression):
def always_true(expression):
    """Return a truthy value if `expression` is statically known to be true.

    That is the case for the boolean TRUE and for numeric literals whose value is not
    zero. A literal `0` is false in SQL and the truthiness of a string literal is not
    statically known here, so neither is reported as always true.
    """
    if isinstance(expression, exp.Boolean):
        return bool(expression.this)

    if isinstance(expression, exp.Literal) and expression.is_number:
        try:
            return Decimal(expression.name) != 0
        except ArithmeticError:
            # decimal.InvalidOperation: the literal text could not be parsed as a
            # number, so nothing is known about its truthiness.
            return False

    return False
def always_false(expression):
def always_false(expression):
    """Return True if `expression` is the boolean FALSE or a NULL node."""
    return is_false(expression) or is_null(expression)
def is_complement(a, b):
def is_complement(a, b):
    """Return True if `b` is a NOT node whose operand equals `a`."""
    return isinstance(b, exp.Not) and b.this == a
def is_false(a: sqlglot.expressions.Expression) -> bool:
def is_false(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Boolean` whose value is falsy."""
    return type(a) is exp.Boolean and not a.this
def is_null(a: sqlglot.expressions.Expression) -> bool:
def is_null(a: exp.Expression) -> bool:
    """Return True if `a` is exactly an `exp.Null` node (subclasses do not count)."""
    return type(a) is exp.Null
def eval_boolean(expression, a, b):
def eval_boolean(expression, a, b):
    """Evaluate the comparison that `expression` stands for on the values `a` and `b`.

    Returns the result as a boolean literal node, or None when `expression` is not
    one of the supported comparison types.
    """
    comparators = (
        ((exp.EQ, exp.Is), lambda x, y: x == y),
        (exp.NEQ, lambda x, y: x != y),
        (exp.GT, lambda x, y: x > y),
        (exp.GTE, lambda x, y: x >= y),
        (exp.LT, lambda x, y: x < y),
        (exp.LTE, lambda x, y: x <= y),
    )

    # The first matching entry wins, mirroring a chain of isinstance checks
    for node_types, compare in comparators:
        if isinstance(expression, node_types):
            return boolean_literal(compare(a, b))

    return None
def cast_as_date(value: Any) -> Optional[datetime.date]:
def cast_as_date(value: t.Any) -> t.Optional[datetime.date]:
    """Convert `value` to a `datetime.date`.

    Datetimes are truncated to their date, dates are returned as they are, and any
    other value is parsed as an ISO formatted string. Returns None when the string
    cannot be parsed.
    """
    if isinstance(value, datetime.date):
        # datetime.datetime is a subclass of datetime.date
        return value.date() if isinstance(value, datetime.datetime) else value

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        return None
    return parsed.date()
def cast_as_datetime(value: Any) -> Optional[datetime.datetime]:
def cast_as_datetime(value: t.Any) -> t.Optional[datetime.datetime]:
    """Convert `value` to a `datetime.datetime`.

    Datetimes are returned as they are, dates become midnight of that day, and any
    other value is parsed as an ISO formatted string. Returns None when the string
    cannot be parsed.
    """
    if isinstance(value, datetime.datetime):
        return value

    if isinstance(value, datetime.date):
        return datetime.datetime(value.year, value.month, value.day)

    try:
        parsed = datetime.datetime.fromisoformat(value)
    except ValueError:
        parsed = None
    return parsed
def cast_value(value: Any, to: sqlglot.expressions.DataType) -> Optional[datetime.date]:
def cast_value(
    value: t.Any, to: exp.DataType
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Convert `value` to the Python date type that corresponds to the data type `to`.

    A falsy `value` yields None. DATE targets go through `cast_as_date`, any other
    type in `exp.DataType.TEMPORAL_TYPES` goes through `cast_as_datetime`, and every
    other target type yields None.
    """
    if not value:
        return None
    # DATE is checked first so that it is not handled by the temporal branch below
    if to.is_type(exp.DataType.Type.DATE):
        return cast_as_date(value)
    if to.is_type(*exp.DataType.TEMPORAL_TYPES):
        return cast_as_datetime(value)
    return None
def extract_date(cast: sqlglot.expressions.Expression) -> Optional[datetime.date]:
def extract_date(
    cast: exp.Expression,
) -> t.Optional[t.Union[datetime.date, datetime.datetime]]:
    """Statically evaluate a cast of a literal to a date or datetime value.

    Supports `Cast` nodes and `TsOrDsToDate` nodes without a format. The operand must
    be a literal or another such node (evaluated recursively); otherwise None is
    returned. The final conversion is done by `cast_value`.
    """
    if isinstance(cast, exp.Cast):
        to = cast.to
    elif isinstance(cast, exp.TsOrDsToDate) and not cast.args.get("format"):
        # TsOrDsToDate without a format is treated like a cast to DATE
        to = exp.DataType.build(exp.DataType.Type.DATE)
    else:
        return None

    if isinstance(cast.this, exp.Literal):
        value: t.Any = cast.this.name
    elif isinstance(cast.this, (exp.Cast, exp.TsOrDsToDate)):
        # Nested cast: evaluate the inner one first
        value = extract_date(cast.this)
    else:
        return None
    return cast_value(value, to)
def extract_interval(expression):
def extract_interval(expression):
    """Build the `interval` delta described by `expression`.

    The amount is read from `expression.name` as an int and the unit from the
    lower-cased "unit" text. Returns None if the amount is not an int, the unit is
    unsupported, or `interval` fails with ModuleNotFoundError (it imports dateutil).
    """
    try:
        n = int(expression.name)
        unit = expression.text("unit").lower()
        return interval(unit, n)
    except (UnsupportedUnit, ModuleNotFoundError, ValueError):
        return None
def date_literal(date):
def date_literal(date):
    """Wrap `date` in a string literal that is cast to DATETIME or DATE.

    DATETIME is used for `datetime.datetime` values, DATE for everything else.
    """
    if isinstance(date, datetime.datetime):
        target_type = exp.DataType.Type.DATETIME
    else:
        target_type = exp.DataType.Type.DATE

    return exp.cast(exp.Literal.string(date), target_type)
def interval(unit: str, n: int = 1):
def interval(unit: str, n: int = 1):
    """Return a dateutil `relativedelta` spanning `n` times the given unit.

    Supported units are year, quarter, month, week, day, hour, minute and second; a
    quarter is expressed as three months. Raises UnsupportedUnit for any other unit.
    """
    from dateutil.relativedelta import relativedelta

    # unit -> (relativedelta keyword, how many of it make up one unit)
    spans = {
        "year": ("years", 1),
        "quarter": ("months", 3),
        "month": ("months", 1),
        "week": ("weeks", 1),
        "day": ("days", 1),
        "hour": ("hours", 1),
        "minute": ("minutes", 1),
        "second": ("seconds", 1),
    }

    if unit not in spans:
        raise UnsupportedUnit(f"Unsupported unit: {unit}")

    keyword, factor = spans[unit]
    return relativedelta(**{keyword: factor * n})
def date_floor( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_floor(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` down to the start of the given unit.

    Args:
        d: the date to floor.
        unit: one of "year", "quarter", "month", "week" or "day".
        dialect: only its `WEEK_OFFSET` attribute is read, for the "week" unit.

    Returns:
        The first day of the year / quarter / month / week containing `d`; for "day"
        the value is returned unchanged.

    Raises:
        UnsupportedUnit: if `unit` is none of the units above.
    """
    if unit == "year":
        return d.replace(month=1, day=1)
    if unit == "quarter":
        # First month of the quarter: 1, 4, 7 or 10
        return d.replace(month=3 * ((d.month - 1) // 3) + 1, day=1)
    if unit == "month":
        return d.replace(day=1)
    if unit == "week":
        # weekday() is 0 for Monday through 6 for Sunday. WEEK_OFFSET presumably shifts
        # the first day of the week relative to Monday (e.g. -1 for Sunday) -- confirm
        # against the Dialect definition. The modulo keeps the distance to the week start
        # within 0..6, so the result is never in the future nor a whole week too early.
        days_since_week_start = (d.weekday() - dialect.WEEK_OFFSET) % 7
        return d - datetime.timedelta(days=days_since_week_start)
    if unit == "day":
        return d

    raise UnsupportedUnit(f"Unsupported unit: {unit}")
def date_ceil( d: datetime.date, unit: str, dialect: sqlglot.dialects.dialect.Dialect) -> datetime.date:
def date_ceil(d: datetime.date, unit: str, dialect: Dialect) -> datetime.date:
    """Round `d` up to the next unit boundary.

    A value that already sits on a boundary (its floor equals itself) is returned
    unchanged; otherwise the result is the floor plus one unit.
    """
    lower_bound = date_floor(d, unit, dialect)
    return d if lower_bound == d else lower_bound + interval(unit)
def boolean_literal(condition):
def boolean_literal(condition):
    """Return `exp.true()` if `condition` is truthy, otherwise `exp.false()`."""
    return exp.true() if condition else exp.false()
def gen(expression: Any) -> str:
def gen(expression: t.Any) -> str:
    """Render `expression` with the lightweight pseudo SQL generator `Gen`.

    The resulting strings are meant for sorting and deduplicating expressions during
    optimization, where invoking the real SQL generator would be too expensive.
    """
    generator = Gen()
    return generator.gen(expression)

Simple pseudo sql generator for quickly generating sortable and uniq strings.

Sorting and deduping sql is a necessary step for optimization. Calling the actual generator is expensive so we have a bare minimum sql generator here.

class Gen:
class Gen:
    """Minimal stack based pseudo SQL generator producing sortable, comparable strings.

    Work items (expressions, lists and plain strings) are pushed onto `self.stack` in
    reverse output order, because the main loop in `gen` pops from the end of the
    list. Emitted text fragments are collected in `self.sqls`.
    """

    def __init__(self):
        self.stack = []
        self.sqls = []

    def gen(self, expression: exp.Expression) -> str:
        """Return the pseudo SQL string for `expression`."""
        self.stack = [expression]
        self.sqls.clear()

        while self.stack:
            node = self.stack.pop()

            if isinstance(node, exp.Expression):
                exp_handler_name = f"{node.key}_sql"

                if hasattr(self, exp_handler_name):
                    # A dedicated `<key>_sql` handler exists for this node type
                    getattr(self, exp_handler_name)(node)
                elif isinstance(node, exp.Func):
                    self._function(node)
                else:
                    # Generic fallback: upper-cased key followed by the node's args
                    key = node.key.upper()
                    self.stack.append(f"{key} " if self._args(node) else key)
            elif type(node) is list:
                pushed = False
                for n in reversed(node):
                    if n is not None:
                        self.stack.extend((n, ","))
                        pushed = True
                # Drop the separator that trails the first element. Only do so when
                # something was actually pushed: a list holding nothing but None would
                # otherwise remove an unrelated item from the stack.
                if pushed:
                    self.stack.pop()
            elif node is not None:
                self.sqls.append(str(node))

        return "".join(self.sqls)

    def add_sql(self, e: exp.Add) -> None:
        self._binary(e, " + ")

    def alias_sql(self, e: exp.Alias) -> None:
        # Emitted as: <this> AS <alias>
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )

    def and_sql(self, e: exp.And) -> None:
        self._binary(e, " AND ")

    def anonymous_sql(self, e: exp.Anonymous) -> None:
        this = e.this
        if isinstance(this, str):
            name = this.upper()
        elif isinstance(this, exp.Identifier):
            name = this.this
            # Quoted names keep their case, unquoted ones are upper-cased
            name = f'"{name}"' if this.quoted else name.upper()
        else:
            raise ValueError(
                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
            )

        # Emitted as: NAME(<expressions>)
        self.stack.extend(
            (
                ")",
                e.expressions,
                "(",
                name,
            )
        )

    def between_sql(self, e: exp.Between) -> None:
        # Emitted as: <this> BETWEEN <low> AND <high>
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )

    def boolean_sql(self, e: exp.Boolean) -> None:
        self.stack.append("TRUE" if e.this else "FALSE")

    def bracket_sql(self, e: exp.Bracket) -> None:
        # Emitted as: <this>[<expressions>]
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )

    def column_sql(self, e: exp.Column) -> None:
        # Parts separated by dots; the final pop removes the separator that would
        # otherwise lead the output.
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def datatype_sql(self, e: exp.DataType) -> None:
        # Type name first, then every arg except the first declared one
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")

    def div_sql(self, e: exp.Div) -> None:
        self._binary(e, " / ")

    def dot_sql(self, e: exp.Dot) -> None:
        self._binary(e, ".")

    def eq_sql(self, e: exp.EQ) -> None:
        self._binary(e, " = ")

    def from_sql(self, e: exp.From) -> None:
        self.stack.extend((e.this, "FROM "))

    def gt_sql(self, e: exp.GT) -> None:
        self._binary(e, " > ")

    def gte_sql(self, e: exp.GTE) -> None:
        self._binary(e, " >= ")

    def identifier_sql(self, e: exp.Identifier) -> None:
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)

    def ilike_sql(self, e: exp.ILike) -> None:
        self._binary(e, " ILIKE ")

    def in_sql(self, e: exp.In) -> None:
        # Emitted as: <this> IN (<every arg except the first declared one>)
        self.stack.append(")")
        self._args(e, 1)
        self.stack.extend(
            (
                "(",
                " IN ",
                e.this,
            )
        )

    def intdiv_sql(self, e: exp.IntDiv) -> None:
        self._binary(e, " DIV ")

    def is_sql(self, e: exp.Is) -> None:
        self._binary(e, " IS ")

    def like_sql(self, e: exp.Like) -> None:
        self._binary(e, " Like ")

    def literal_sql(self, e: exp.Literal) -> None:
        self.stack.append(f"'{e.this}'" if e.is_string else e.this)

    def lt_sql(self, e: exp.LT) -> None:
        self._binary(e, " < ")

    def lte_sql(self, e: exp.LTE) -> None:
        self._binary(e, " <= ")

    def mod_sql(self, e: exp.Mod) -> None:
        self._binary(e, " % ")

    def mul_sql(self, e: exp.Mul) -> None:
        self._binary(e, " * ")

    def neg_sql(self, e: exp.Neg) -> None:
        self._unary(e, "-")

    def neq_sql(self, e: exp.NEQ) -> None:
        self._binary(e, " <> ")

    def not_sql(self, e: exp.Not) -> None:
        self._unary(e, "NOT ")

    def null_sql(self, e: exp.Null) -> None:
        self.stack.append("NULL")

    def or_sql(self, e: exp.Or) -> None:
        self._binary(e, " OR ")

    def paren_sql(self, e: exp.Paren) -> None:
        self.stack.extend(
            (
                ")",
                e.this,
                "(",
            )
        )

    def sub_sql(self, e: exp.Sub) -> None:
        self._binary(e, " - ")

    def subquery_sql(self, e: exp.Subquery) -> None:
        # (<this>), then the alias if any, then the args from the third declared one on
        self._args(e, 2)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        self.stack.extend((")", e.this, "("))

    def table_sql(self, e: exp.Table) -> None:
        # Dotted parts, then the alias if any, then the args from the fifth declared one on
        self._args(e, 4)
        alias = e.args.get("alias")
        if alias:
            self.stack.append(alias)
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        self.stack.pop()

    def tablealias_sql(self, e: exp.TableAlias) -> None:
        columns = e.columns

        if columns:
            self.stack.extend((")", columns, "("))

        self.stack.extend((e.this, " AS "))

    def var_sql(self, e: exp.Var) -> None:
        self.stack.append(e.this)

    def _binary(self, e: exp.Binary, op: str) -> None:
        """Queue `<this><op><expression>`."""
        self.stack.extend((e.expression, op, e.this))

    def _unary(self, e: exp.Unary, op: str) -> None:
        """Queue `<op><this>`."""
        self.stack.extend((e.this, op))

    def _function(self, e: exp.Func) -> None:
        """Queue `NAME(<all arg values>)` using the function's `sql_name()`."""
        self.stack.extend(
            (
                ")",
                list(e.args.values()),
                "(",
                e.sql_name(),
            )
        )

    def _args(self, node: exp.Expression, arg_index: int = 0) -> bool:
        """Queue the set args of `node` as `:key value` pairs.

        Args are visited in the order of `node.arg_types`, skipping the first
        `arg_index` declared ones, and args whose value is None are left out.

        Returns:
            True if at least one arg was queued, False otherwise.
        """
        kvs = []
        arg_types = list(node.arg_types)[arg_index:] if arg_index else node.arg_types

        for k in arg_types:
            v = node.args.get(k)

            if v is not None:
                kvs.append([f":{k}", v])
        if kvs:
            self.stack.append(kvs)
            return True
        return False
stack
sqls
def gen(self, expression: sqlglot.expressions.Expression) -> str:
1188    def gen(self, expression: exp.Expression) -> str:
1189        self.stack = [expression]
1190        self.sqls.clear()
1191
1192        while self.stack:
1193            node = self.stack.pop()
1194
1195            if isinstance(node, exp.Expression):
1196                exp_handler_name = f"{node.key}_sql"
1197
1198                if hasattr(self, exp_handler_name):
1199                    getattr(self, exp_handler_name)(node)
1200                elif isinstance(node, exp.Func):
1201                    self._function(node)
1202                else:
1203                    key = node.key.upper()
1204                    self.stack.append(f"{key} " if self._args(node) else key)
1205            elif type(node) is list:
1206                for n in reversed(node):
1207                    if n is not None:
1208                        self.stack.extend((n, ","))
1209                if node:
1210                    self.stack.pop()
1211            else:
1212                if node is not None:
1213                    self.sqls.append(str(node))
1214
1215        return "".join(self.sqls)
def add_sql(self, e: sqlglot.expressions.Add) -> None:
    def add_sql(self, e: exp.Add) -> None:
        """Queue an addition as `<this> + <expression>`."""
        self._binary(e, " + ")
def alias_sql(self, e: sqlglot.expressions.Alias) -> None:
    def alias_sql(self, e: exp.Alias) -> None:
        """Queue an alias as `<this> AS <alias>`."""
        # Pushed in reverse, since the stack is consumed from its end
        self.stack.extend(
            (
                e.args.get("alias"),
                " AS ",
                e.args.get("this"),
            )
        )
def and_sql(self, e: sqlglot.expressions.And) -> None:
    def and_sql(self, e: exp.And) -> None:
        """Queue a conjunction as `<this> AND <expression>`."""
        self._binary(e, " AND ")
def anonymous_sql(self, e: sqlglot.expressions.Anonymous) -> None:
1232    def anonymous_sql(self, e: exp.Anonymous) -> None:
1233        this = e.this
1234        if isinstance(this, str):
1235            name = this.upper()
1236        elif isinstance(this, exp.Identifier):
1237            name = this.this
1238            name = f'"{name}"' if this.quoted else name.upper()
1239        else:
1240            raise ValueError(
1241                f"Anonymous.this expects a str or an Identifier, got '{this.__class__.__name__}'."
1242            )
1243
1244        self.stack.extend(
1245            (
1246                ")",
1247                e.expressions,
1248                "(",
1249                name,
1250            )
1251        )
def between_sql(self, e: sqlglot.expressions.Between) -> None:
    def between_sql(self, e: exp.Between) -> None:
        """Queue a range check as `<this> BETWEEN <low> AND <high>`."""
        # Pushed in reverse, since the stack is consumed from its end
        self.stack.extend(
            (
                e.args.get("high"),
                " AND ",
                e.args.get("low"),
                " BETWEEN ",
                e.this,
            )
        )
def boolean_sql(self, e: sqlglot.expressions.Boolean) -> None:
    def boolean_sql(self, e: exp.Boolean) -> None:
        """Queue `TRUE` or `FALSE` depending on the truthiness of `e.this`."""
        self.stack.append("TRUE" if e.this else "FALSE")
def bracket_sql(self, e: sqlglot.expressions.Bracket) -> None:
    def bracket_sql(self, e: exp.Bracket) -> None:
        """Queue a bracket access as `<this>[<expressions>]`."""
        # Pushed in reverse, since the stack is consumed from its end
        self.stack.extend(
            (
                "]",
                e.expressions,
                "[",
                e.this,
            )
        )
def column_sql(self, e: sqlglot.expressions.Column) -> None:
    def column_sql(self, e: exp.Column) -> None:
        """Queue the parts of a column reference separated by dots."""
        for p in reversed(e.parts):
            self.stack.extend((p, "."))
        # Remove the separator pushed after the first part; it would lead the output
        self.stack.pop()
def datatype_sql(self, e: sqlglot.expressions.DataType) -> None:
    def datatype_sql(self, e: exp.DataType) -> None:
        """Queue the type name followed by the data type's remaining args."""
        # Args are queued first so that the type name, pushed last, is emitted first;
        # index 1 skips the first declared arg.
        self._args(e, 1)
        self.stack.append(f"{e.this.name} ")
def div_sql(self, e: sqlglot.expressions.Div) -> None:
    def div_sql(self, e: exp.Div) -> None:
        """Queue a division as `<this> / <expression>`."""
        self._binary(e, " / ")
def dot_sql(self, e: sqlglot.expressions.Dot) -> None:
    def dot_sql(self, e: exp.Dot) -> None:
        """Queue a dot access as `<this>.<expression>`."""
        self._binary(e, ".")
def eq_sql(self, e: sqlglot.expressions.EQ) -> None:
    def eq_sql(self, e: exp.EQ) -> None:
        """Queue an equality as `<this> = <expression>`."""
        self._binary(e, " = ")
def from_sql(self, e: sqlglot.expressions.From) -> None:
    def from_sql(self, e: exp.From) -> None:
        """Queue a FROM clause as `FROM <this>`."""
        self.stack.extend((e.this, "FROM "))
def gt_sql(self, e: sqlglot.expressions.GT) -> None:
    def gt_sql(self, e: exp.GT) -> None:
        """Queue a comparison as `<this> > <expression>`."""
        self._binary(e, " > ")
def gte_sql(self, e: sqlglot.expressions.GTE) -> None:
    def gte_sql(self, e: exp.GTE) -> None:
        """Queue a comparison as `<this> >= <expression>`."""
        self._binary(e, " >= ")
def identifier_sql(self, e: sqlglot.expressions.Identifier) -> None:
    def identifier_sql(self, e: exp.Identifier) -> None:
        """Queue the identifier name, wrapped in double quotes if it is quoted."""
        self.stack.append(f'"{e.this}"' if e.quoted else e.this)
def ilike_sql(self, e: sqlglot.expressions.ILike) -> None:
1307    def ilike_sql(self, e: exp.ILike) -> None:
1308        self._binary(e, " ILIKE ")
def in_sql(self, e: sqlglot.expressions.In) -> None:
1310    def in_sql(self, e: exp.In) -> None:
1311        self.stack.append(")")
1312        self._args(e, 1)
1313        self.stack.extend(
1314            (
1315                "(",
1316                " IN ",
1317                e.this,
1318            )
1319        )
def intdiv_sql(self, e: sqlglot.expressions.IntDiv) -> None:
1321    def intdiv_sql(self, e: exp.IntDiv) -> None:
1322        self._binary(e, " DIV ")
def is_sql(self, e: sqlglot.expressions.Is) -> None:
1324    def is_sql(self, e: exp.Is) -> None:
1325        self._binary(e, " IS ")
def like_sql(self, e: sqlglot.expressions.Like) -> None:
1327    def like_sql(self, e: exp.Like) -> None:
1328        self._binary(e, " Like ")
def literal_sql(self, e: sqlglot.expressions.Literal) -> None:
1330    def literal_sql(self, e: exp.Literal) -> None:
1331        self.stack.append(f"'{e.this}'" if e.is_string else e.this)
def lt_sql(self, e: sqlglot.expressions.LT) -> None:
1333    def lt_sql(self, e: exp.LT) -> None:
1334        self._binary(e, " < ")
def lte_sql(self, e: sqlglot.expressions.LTE) -> None:
1336    def lte_sql(self, e: exp.LTE) -> None:
1337        self._binary(e, " <= ")
def mod_sql(self, e: sqlglot.expressions.Mod) -> None:
1339    def mod_sql(self, e: exp.Mod) -> None:
1340        self._binary(e, " % ")
def mul_sql(self, e: sqlglot.expressions.Mul) -> None:
1342    def mul_sql(self, e: exp.Mul) -> None:
1343        self._binary(e, " * ")
def neg_sql(self, e: sqlglot.expressions.Neg) -> None:
1345    def neg_sql(self, e: exp.Neg) -> None:
1346        self._unary(e, "-")
def neq_sql(self, e: sqlglot.expressions.NEQ) -> None:
1348    def neq_sql(self, e: exp.NEQ) -> None:
1349        self._binary(e, " <> ")
def not_sql(self, e: sqlglot.expressions.Not) -> None:
1351    def not_sql(self, e: exp.Not) -> None:
1352        self._unary(e, "NOT ")
def null_sql(self, e: sqlglot.expressions.Null) -> None:
1354    def null_sql(self, e: exp.Null) -> None:
1355        self.stack.append("NULL")
def or_sql(self, e: sqlglot.expressions.Or) -> None:
1357    def or_sql(self, e: exp.Or) -> None:
1358        self._binary(e, " OR ")
def paren_sql(self, e: sqlglot.expressions.Paren) -> None:
1360    def paren_sql(self, e: exp.Paren) -> None:
1361        self.stack.extend(
1362            (
1363                ")",
1364                e.this,
1365                "(",
1366            )
1367        )
def sub_sql(self, e: sqlglot.expressions.Sub) -> None:
1369    def sub_sql(self, e: exp.Sub) -> None:
1370        self._binary(e, " - ")
def subquery_sql(self, e: sqlglot.expressions.Subquery) -> None:
1372    def subquery_sql(self, e: exp.Subquery) -> None:
1373        self._args(e, 2)
1374        alias = e.args.get("alias")
1375        if alias:
1376            self.stack.append(alias)
1377        self.stack.extend((")", e.this, "("))
def table_sql(self, e: sqlglot.expressions.Table) -> None:
1379    def table_sql(self, e: exp.Table) -> None:
1380        self._args(e, 4)
1381        alias = e.args.get("alias")
1382        if alias:
1383            self.stack.append(alias)
1384        for p in reversed(e.parts):
1385            self.stack.extend((p, "."))
1386        self.stack.pop()
def tablealias_sql(self, e: sqlglot.expressions.TableAlias) -> None:
1388    def tablealias_sql(self, e: exp.TableAlias) -> None:
1389        columns = e.columns
1390
1391        if columns:
1392            self.stack.extend((")", columns, "("))
1393
1394        self.stack.extend((e.this, " AS "))
def var_sql(self, e: sqlglot.expressions.Var) -> None:
1396    def var_sql(self, e: exp.Var) -> None:
1397        self.stack.append(e.this)