
Getting Started with Tree Rewriter

A hands-on tutorial for building powerful tree transformations with minimal code.

Target Audience: Python developers familiar with basic data structures who want to learn term rewriting for expression simplification, compiler transformations, or symbolic computation.

What You Will Learn:
  • How to represent expressions as trees (S-expressions)
  • How to write transformation rules using pattern matching
  • How to apply rules until a fixed point is reached
  • How to build a complete arithmetic simplifier

Prerequisites:
  • Python 3.8 or later
  • Basic understanding of tuples and functions

Estimated Time: 30 minutes


Table of Contents

  1. Installation
  2. Core Concepts
  3. Your First Rule
  4. Pattern Matching
  5. Traversal Strategies
  6. Composing Rules
  7. Practical Example: Arithmetic Simplifier
  8. Next Steps

Installation

Install Tree Rewriter using pip:

pip install tree-rewriter

Alternatively, since the entire library is about 100 lines of code, you can copy the source file directly into your project.

Verify the installation:

from tree_rewriter import rewrite, when, _
print("Tree Rewriter installed successfully!")

Core Concepts

Tree Rewriter is built on three simple ideas:

1. Trees as S-Expressions

Trees are represented as nested tuples (S-expressions). This format is simple, universal, and easy to work with:

# Atomic values (leaves)
5                        # A number
'x'                      # A variable name

# Compound expressions (internal nodes)
('+', 'x', 5)            # x + 5
('*', 2, ('+', 'x', 5))  # 2 * (x + 5)
('if', 'cond', 'then', 'else')  # A conditional expression

The first element of a tuple is typically an operator or tag, and the remaining elements are its arguments.
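Because trees are plain tuples, ordinary Python is enough to take them apart. The sketch below uses two hypothetical helpers, `is_node` and `count_nodes`, which are not part of Tree Rewriter. It splits a node into its operator and arguments and counts the nodes in a tree:

```python
def is_node(tree):
    """Return True for compound expressions, i.e. non-empty tuples."""
    return isinstance(tree, tuple) and len(tree) > 0

def count_nodes(tree):
    """Count leaves and internal nodes (the operator tag is not counted separately)."""
    if not is_node(tree):
        return 1  # a leaf such as 5 or 'x'
    op, *args = tree
    return 1 + sum(count_nodes(arg) for arg in args)

expr = ('*', 2, ('+', 'x', 5))
op, *args = expr          # operator first, arguments after
print(op)                 # *
print(args)               # [2, ('+', 'x', 5)]
print(count_nodes(expr))  # 5
```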

2. Rules as Functions

A rule is simply a function that takes a tree and returns a (possibly transformed) tree:

def add_zero_rule(tree):
    """Transform the exact tree ('+', 'x', 0) into 'x'."""
    if tree == ('+', 'x', 0):
        return 'x'
    return tree  # No change

If the rule does not apply, it returns the tree unchanged.
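The rule above recognizes only the literal tree ('+', 'x', 0). A hand-written rule that accepts any left operand has to check the shape of the tree before taking it apart, as in this plain-Python sketch. The name `add_zero_any` is hypothetical:

```python
def add_zero_any(tree):
    """Rewrite ('+', a, 0) to a for any a; otherwise return tree unchanged."""
    # Shape check: a three-element tuple tagged '+' whose last argument is 0.
    if isinstance(tree, tuple) and len(tree) == 3 and tree[0] == '+' and tree[2] == 0:
        return tree[1]
    return tree  # no change

print(add_zero_any(('+', 'y', 0)))            # y
print(add_zero_any(('+', ('*', 2, 'z'), 0)))  # ('*', 2, 'z')
print(add_zero_any(('+', 'y', 1)))            # ('+', 'y', 1)
```

This kind of shape checking is what the `when`/`then` API in the next section is meant to replace.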

3. Fixed-Point Rewriting

The rewrite function applies rules repeatedly until the tree stops changing:

from tree_rewriter import rewrite

def example_rule(tree):
    if tree == ('step', 1):
        return ('step', 2)
    if tree == ('step', 2):
        return ('step', 3)
    return tree

result = rewrite(('step', 1), example_rule)
print(result)  # ('step', 3)

The rewriter applies rules in order. When a rule changes the tree, it restarts from the first rule. When no rule applies, the tree has reached a "fixed point" and rewriting stops.
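The loop described above can be modeled in a few lines. The following is a minimal sketch based on that description, not the library's source. `rewrite_sketch` is a hypothetical name:

```python
def rewrite_sketch(tree, *rules):
    """Apply rules until none of them changes the tree (a fixed point)."""
    while True:
        for rule in rules:
            new_tree = rule(tree)
            if new_tree != tree:
                tree = new_tree
                break  # a rule fired: restart from the first rule
        else:
            return tree  # no rule changed the tree: fixed point reached

def example_rule(tree):
    if tree == ('step', 1):
        return ('step', 2)
    if tree == ('step', 2):
        return ('step', 3)
    return tree

print(rewrite_sketch(('step', 1), example_rule))  # ('step', 3)
```

Because the sketch detects change with `!=`, a rule that builds a new tuple equal to its input counts as "no change". This is what allows the loop to stop.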


Your First Rule

Let us create a simple rule that eliminates addition by zero. Instead of writing raw functions, Tree Rewriter provides a fluent API using when and then:

from tree_rewriter import rewrite, when, _

# Create a rule: when we see ('+', 0, something), return that something
rule = when('+', 0, _).then(lambda x: x)

# Test it
expr = ('+', 0, 'x')
result = rewrite(expr, rule)
print(result)  # 'x'

Breaking this down:

  • when('+', 0, _) defines a pattern to match
  • _ is a wildcard that matches anything
  • .then(lambda x: x) specifies what to produce when the pattern matches
  • The value matched by the wildcard is passed to the lambda as its argument x

You can also return a constant value:

# Multiplication by zero always yields zero
rule = when('*', 0, _).then(0)

result = rewrite(('*', 0, 'anything'), rule)
print(result)  # 0
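To show how `then` can accept either a function or a constant, here is a toy model that supports flat patterns only. `ANY`, `WhenSketch`, and the callable-versus-constant dispatch are assumptions made for illustration. They are not Tree Rewriter's implementation:

```python
class _Wildcard:
    """Hypothetical stand-in for Tree Rewriter's _ sentinel."""
    def __repr__(self):
        return '_'

ANY = _Wildcard()

class WhenSketch:
    """Toy model of when(...).then(...) supporting flat patterns only."""

    def __init__(self, *pattern):
        self.pattern = pattern

    def then(self, action):
        pattern = self.pattern

        def rule(tree):
            # Shape check: a tuple with the same length as the pattern.
            if not (isinstance(tree, tuple) and len(tree) == len(pattern)):
                return tree
            captured = []
            for pat, value in zip(pattern, tree):
                if pat is ANY:
                    captured.append(value)  # wildcards capture left to right
                elif pat != value:
                    return tree  # exact-value mismatch: leave the tree unchanged
            # A callable action receives the captures; anything else is a constant.
            return action(*captured) if callable(action) else action

        return rule

add_zero = WhenSketch('+', 0, ANY).then(lambda x: x)
mul_zero = WhenSketch('*', 0, ANY).then(0)
print(add_zero(('+', 0, 'x')))         # x
print(mul_zero(('*', 0, 'anything')))  # 0
print(add_zero(('+', 1, 'x')))         # ('+', 1, 'x')
```

Treating any callable as an action is one simple way to let `then` take both forms. The trade-off is that, in this sketch, a constant result can never itself be a function.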

Pattern Matching

Tree Rewriter provides several pattern matching features that make rules expressive and readable.

Wildcards

The underscore _ matches any single value:

from tree_rewriter import when, _

# Match any binary addition
rule = when('+', _, _).then(lambda a, b: f"Adding {a} and {b}")

Each wildcard binds to the next parameter in the lambda, in order from left to right.

Named Variables

Use $name syntax to create named bindings. Named variables must match consistently:

# x - x = 0 (the same value on both sides)
rule = when('-', '$x', '$x').then(0)

# This matches because both arguments are 'y'
rewrite(('-', 'y', 'y'), rule)  # Returns 0

# This does NOT match because arguments differ
rewrite(('-', 'a', 'b'), rule)  # Returns ('-', 'a', 'b') unchanged

Named variables are passed to the lambda in the order they first appear:

# Swap two values
rule = when('swap', '$a', '$b').then(lambda a, b: ('swap', b, a))

rewrite(('swap', 1, 2), rule)  # Returns ('swap', 2, 1)
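Consistent binding and first-appearance order can both be modeled with a Python dict, because dicts preserve insertion order. A minimal sketch for flat patterns follows. `match_named` is a hypothetical helper, not the library's matcher:

```python
def match_named(pattern, tree):
    """Match a flat pattern whose '$name' strings are named variables.

    Returns a dict of bindings on success or None on failure.
    """
    if not (isinstance(tree, tuple) and len(tree) == len(pattern)):
        return None
    bindings = {}
    for pat, value in zip(pattern, tree):
        if isinstance(pat, str) and pat.startswith('$'):
            if pat in bindings and bindings[pat] != value:
                return None  # same name bound to two different values
            bindings.setdefault(pat, value)  # first appearance fixes the order
        elif pat != value:
            return None
    return bindings

print(match_named(('-', '$x', '$x'), ('-', 'y', 'y')))  # {'$x': 'y'}
print(match_named(('-', '$x', '$x'), ('-', 'a', 'b')))  # None

swap = match_named(('swap', '$a', '$b'), ('swap', 1, 2))
swapper = lambda a, b: ('swap', b, a)
print(swapper(*swap.values()))  # ('swap', 2, 1)
```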

Predicates

Instead of matching exact values, you can use predicate functions that return True or False:

from tree_rewriter import when, is_literal, is_type

# is_literal matches numbers, booleans, None, and complex numbers
rule = when('+', is_literal, is_literal).then(lambda a, b: a + b)

rewrite(('+', 3, 4), rule)  # Returns 7
rewrite(('+', 'x', 4), rule)  # Returns ('+', 'x', 4) - no match

# is_type creates custom type predicates
is_int = is_type(int)
is_number = is_type(int, float)

You can also write inline predicates:

# Match only positive numbers
is_positive = lambda x: isinstance(x, (int, float)) and x > 0

rule = when('double', is_positive).then(lambda x: x * 2)
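One Python detail matters for type predicates: `bool` is a subclass of `int`, so `isinstance(True, int)` is `True`. This tutorial does not say whether the library's `is_type(int)` accepts booleans. The plain-Python sketch below shows the default `isinstance` behavior and one way to exclude booleans. `is_type_sketch` and `is_strict_int` are hypothetical names:

```python
def is_type_sketch(*types):
    """Build a predicate that accepts instances of any of the given types."""
    return lambda value: isinstance(value, types)

is_int = is_type_sketch(int)
print(is_int(3), is_int(True), is_int('x'))  # True True False

def is_strict_int(value):
    """Accept ints but reject booleans, which subclass int."""
    return isinstance(value, int) and not isinstance(value, bool)

print(is_strict_int(3), is_strict_int(True))  # True False

# The inline is_positive predicate above also accepts True, since True > 0.
is_positive = lambda x: isinstance(x, (int, float)) and x > 0
print(is_positive(True))  # True
```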

Guards (Where Clauses)

Add extra conditions with .where():

# Division, but only when divisor is not zero
rule = when('/', _, _).where(lambda a, b: b != 0).then(lambda a, b: a / b)

rewrite(('/', 10, 2), rule)  # Returns 5.0
rewrite(('/', 10, 0), rule)  # Returns ('/', 10, 0) - guard failed

Guards are checked after the pattern matches. The lambda receives all bound values in order.
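The division rule above has a caveat. Because `_` also matches symbols, ('/', 'x', 'y') passes the `b != 0` guard, and the action then attempts `'x' / 'y'`, which Python rejects with a `TypeError`. The hand-written sketch below follows the order described here: pattern first, then guard, then action. Its guard also requires numeric operands. The name `divide_rule` is hypothetical:

```python
from numbers import Number

def divide_rule(tree):
    """Fold ('/', a, b) to a / b only for numeric a and nonzero numeric b."""
    # 1. Pattern: a three-element tuple tagged '/'.
    if not (isinstance(tree, tuple) and len(tree) == 3 and tree[0] == '/'):
        return tree
    op, a, b = tree
    # 2. Guard: both operands are numbers and the divisor is nonzero.
    if not (isinstance(a, Number) and isinstance(b, Number) and b != 0):
        return tree
    # 3. Action: runs only after the pattern and the guard both succeed.
    return a / b

print(divide_rule(('/', 10, 2)))     # 5.0
print(divide_rule(('/', 10, 0)))     # ('/', 10, 0)
print(divide_rule(('/', 'x', 'y')))  # ('/', 'x', 'y')
```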

Nested Patterns

Patterns can match nested structures:

# Match double negation: not(not(x)) => x
rule = when('not', ('not', '$x')).then(lambda x: x)

rewrite(('not', ('not', 'a')), rule)  # Returns 'a'

# Match a specific nested shape
rule = when('first', ('pair', '$a', '$b')).then(lambda a, b: a)

rewrite(('first', ('pair', 1, 2)), rule)  # Returns 1
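Nested patterns fall out naturally from a recursive matcher. The following is a minimal sketch that handles named variables, nested tuples, and exact values. `match` and `rule_from` are hypothetical helpers, not the library's internals:

```python
def match(pattern, tree, bindings):
    """Return True if tree matches pattern, recording '$name' bindings in place."""
    if isinstance(pattern, str) and pattern.startswith('$'):
        if pattern in bindings:
            return bindings[pattern] == tree  # must agree with the earlier binding
        bindings[pattern] = tree
        return True
    if isinstance(pattern, tuple):
        # Recurse element by element when the shapes line up.
        return (isinstance(tree, tuple)
                and len(tree) == len(pattern)
                and all(match(p, t, bindings) for p, t in zip(pattern, tree)))
    return pattern == tree  # exact value such as 'not' or 0

def rule_from(pattern, action):
    """Turn a pattern and an action into a rule function."""
    def rule(tree):
        bindings = {}  # fresh per call; a failed match may leave partial entries
        if match(pattern, tree, bindings):
            return action(*bindings.values())
        return tree
    return rule

double_not = rule_from(('not', ('not', '$x')), lambda x: x)
first_of_pair = rule_from(('first', ('pair', '$a', '$b')), lambda a, b: a)

print(double_not(('not', ('not', 'a'))))         # a
print(first_of_pair(('first', ('pair', 1, 2))))  # 1
print(double_not(('not', 'a')))                  # ('not', 'a')
```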

Traversal Strategies

By default, rules only apply to the root of the tree. To transform nested subexpressions, use traversal strategies.

Bottom-Up Traversal

The bottom_up function wraps a rule to apply it throughout a tree, starting from the leaves and working up:

from tree_rewriter import rewrite, when, _, bottom_up

# This rule only matches at root level
rule = when('+', 0, _).then(lambda x: x)

# Without bottom_up: inner expression not transformed
expr = ('*', ('+', 0, 'x'), 2)
rewrite(expr, rule)  # Still ('*', ('+', 0, 'x'), 2)

# With bottom_up: transforms the nested ('+', 0, 'x') first
rewrite(expr, bottom_up(rule))  # Returns ('*', 'x', 2)
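A leaves-first traversal can be sketched as a recursive wrapper. This is a hypothetical model of `bottom_up`, not its source. In this sketch the operator tag at position 0 is kept as-is, and only the arguments are rewritten before the rule sees the node:

```python
def bottom_up_sketch(rule):
    """Wrap rule so it runs on every subtree, children before parents."""
    def walk(tree):
        if isinstance(tree, tuple) and tree:
            tree = (tree[0],) + tuple(walk(child) for child in tree[1:])
        return rule(tree)
    return walk

def zero_plus(tree):
    """Plain-function stand-in for when('+', 0, _).then(lambda x: x)."""
    if isinstance(tree, tuple) and len(tree) == 3 and tree[0] == '+' and tree[1] == 0:
        return tree[2]
    return tree

expr = ('*', ('+', 0, 'x'), 2)
print(zero_plus(expr))                    # ('*', ('+', 0, 'x'), 2)
print(bottom_up_sketch(zero_plus)(expr))  # ('*', 'x', 2)
```

A single pass like this is not guaranteed to reach a fixed point on its own. Passing the wrapped rule to `rewrite` repeats the pass until nothing changes.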

Pro Tip: When building a simplifier, wrap each rule with bottom_up so that it applies to every subexpression:

rules = [
    when('+', 0, _).then(lambda x: x),
    when('*', 1, _).then(lambda x: x),
]

# Apply all rules bottom-up
result = rewrite(expr, *[bottom_up(r) for r in rules])

Why Bottom-Up?

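Consider simplifying ('+', ('*', 0, 'y'), 'x') with two rules, 0 + x => x and 0 * x => 0:

  1. Bottom-up first rewrites the inner ('*', 0, 'y') to 0, giving ('+', 0, 'x'), which the 0 + x rule then rewrites to 'x'.

  2. Top-down tries the outer expression first, and ('+', ('*', 0, 'y'), 'x') does not match ('+', 0, _) because its first argument is ('*', 0, 'y'), not 0. The outer rule can fire only if the root is revisited after the inner expression has been simplified.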

Bottom-up ensures simpler subexpressions are created before attempting to match outer patterns.
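This can be checked with plain-function stand-ins for the rules 0 + x => x and 0 * x => 0. These stand-ins are hypothetical, not library code. Applying the rules only at the root leaves ('+', ('*', 0, 'y'), 'x') unchanged, while one leaves-first pass reduces it to 'x':

```python
def zero_plus(tree):
    """0 + x => x"""
    if isinstance(tree, tuple) and len(tree) == 3 and tree[0] == '+' and tree[1] == 0:
        return tree[2]
    return tree

def zero_times(tree):
    """0 * x => 0"""
    if isinstance(tree, tuple) and len(tree) == 3 and tree[0] == '*' and tree[1] == 0:
        return 0
    return tree

def step(tree):
    """Apply the first of the two rules that changes the tree."""
    for rule in (zero_plus, zero_times):
        new_tree = rule(tree)
        if new_tree != tree:
            return new_tree
    return tree

def bottom_up_once(rule, tree):
    """One leaves-first pass: rewrite the arguments, then the node itself."""
    if isinstance(tree, tuple) and tree:
        tree = (tree[0],) + tuple(bottom_up_once(rule, child) for child in tree[1:])
    return rule(tree)

expr = ('+', ('*', 0, 'y'), 'x')
print(step(expr))                  # ('+', ('*', 0, 'y'), 'x')
print(bottom_up_once(step, expr))  # x
```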


Composing Rules

Tree Rewriter provides several ways to combine rules.

Multiple Rules with rewrite

Pass multiple rules to rewrite. Rules are tried in order, and when one succeeds, the process restarts from the first rule:

from tree_rewriter import rewrite, when, _, bottom_up

rules = [
    when('+', 0, _).then(lambda x: x),   # 0 + x = x
    when('+', _, 0).then(lambda x: x),   # x + 0 = x
    when('*', 0, _).then(0),             # 0 * x = 0
    when('*', _, 0).then(0),             # x * 0 = 0
]

expr = ('*', ('+', 'y', 0), 0)
result = rewrite(expr, *[bottom_up(r) for r in rules])
print(result)  # 0

Commutative Helper

Many operators are commutative (order does not matter). The commutative helper generates both orderings:

from tree_rewriter import commutative

# Instead of writing two rules:
# when('+', 0, _).then(lambda x: x)
# when('+', _, 0).then(lambda x: x)

# Write once:
rules = commutative('+', 0, lambda x: x)
# Returns a list of two rules

Use Python's * unpacking operator to include them in your rule list:

all_rules = [
    *commutative('+', 0, lambda x: x),  # 0 + x = x + 0 = x
    *commutative('*', 1, lambda x: x),  # 1 * x = x * 1 = x
    *commutative('*', 0, 0),            # 0 * x = x * 0 = 0
]
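One way such a helper could work is to build one rule for each position of the special operand. The following is a hypothetical model; `commutative_sketch` is not the library's code. Here `result` is either a function of the other operand or a constant:

```python
def commutative_sketch(op, unit, result):
    """Return two rules: one for (op, unit, x) and one for (op, x, unit)."""
    def make_rule(unit_index):
        other_index = 2 if unit_index == 1 else 1

        def rule(tree):
            if (isinstance(tree, tuple) and len(tree) == 3
                    and tree[0] == op and tree[unit_index] == unit):
                other = tree[other_index]
                return result(other) if callable(result) else result
            return tree
        return rule

    return [make_rule(1), make_rule(2)]

plus_zero = commutative_sketch('+', 0, lambda x: x)
times_zero = commutative_sketch('*', 0, 0)
print(len(plus_zero))                # 2
print(plus_zero[0](('+', 0, 'a')))   # a
print(plus_zero[1](('+', 'a', 0)))   # a
print(times_zero[1](('*', 'b', 0)))  # 0
```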

first - Try Rules in Order

The first combinator tries rules in order and returns the result of the first one that changes the tree:

from tree_rewriter import first, when, _

combined = first(
    when('a', _).then(lambda x: ('matched-a', x)),
    when('b', _).then(lambda x: ('matched-b', x)),
)

rewrite(('a', 1), combined)  # ('matched-a', 1)
rewrite(('b', 2), combined)  # ('matched-b', 2)
rewrite(('c', 3), combined)  # ('c', 3) - no match
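The behavior of `first` can be modeled as a loop that stops at the first rule whose output differs from its input. Both `first_sketch` and the helper `tag_rule` are hypothetical:

```python
def first_sketch(*rules):
    """Combine rules: return the result of the first rule that changes the tree."""
    def combined(tree):
        for rule in rules:
            new_tree = rule(tree)
            if new_tree != tree:
                return new_tree
        return tree  # no rule applied
    return combined

def tag_rule(tag, label):
    """Hypothetical helper: rewrite (tag, x) to (label, x)."""
    def rule(tree):
        if isinstance(tree, tuple) and len(tree) == 2 and tree[0] == tag:
            return (label, tree[1])
        return tree
    return rule

combined = first_sketch(tag_rule('a', 'matched-a'), tag_rule('b', 'matched-b'))
print(combined(('a', 1)))  # ('matched-a', 1)
print(combined(('b', 2)))  # ('matched-b', 2)
print(combined(('c', 3)))  # ('c', 3)
```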

all - Apply Rules in Sequence

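The all combinator applies every rule in sequence, passing the result of each to the next. Importing it under the name all shadows Python's built-in all() in that module. If you also need the built-in, import the combinator under an alias (for example, from tree_rewriter import all as all_of).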

from tree_rewriter import all, when, _

step1 = when('val', _).then(lambda x: ('step1', x))
step2 = when('step1', _).then(lambda x: ('step2', x))

pipeline = all(step1, step2)

rewrite(('val', 'data'), pipeline)  # ('step2', 'data')
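Sequential composition is a simple fold over the rules. This hypothetical sketch is named `all_of` so that it does not shadow the built-in `all()`. Unlike `first`, which stops at the first rule that changes the tree, it runs every rule once:

```python
def all_of(*rules):
    """Apply every rule once, in order, feeding each result to the next."""
    def pipeline(tree):
        for rule in rules:
            tree = rule(tree)
        return tree
    return pipeline

def step1(tree):
    """('val', x) => ('step1', x)"""
    if isinstance(tree, tuple) and len(tree) == 2 and tree[0] == 'val':
        return ('step1', tree[1])
    return tree

def step2(tree):
    """('step1', x) => ('step2', x)"""
    if isinstance(tree, tuple) and len(tree) == 2 and tree[0] == 'step1':
        return ('step2', tree[1])
    return tree

pipeline = all_of(step1, step2)
print(pipeline(('val', 'data')))  # ('step2', 'data')
```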

Practical Example: Arithmetic Simplifier

Let us build a complete arithmetic expression simplifier that demonstrates all the concepts.

from tree_rewriter import rewrite, when, _, bottom_up, commutative, is_literal

# Define simplification rules
simplify_rules = [
    # Identity elements: x + 0 = x, x * 1 = x
    *commutative('+', 0, lambda x: x),
    *commutative('*', 1, lambda x: x),

    # Absorbing element: x * 0 = 0
    *commutative('*', 0, 0),

    # Constant folding: compute operations on known values
    when('+', is_literal, is_literal).then(lambda a, b: a + b),
    when('-', is_literal, is_literal).then(lambda a, b: a - b),
    when('*', is_literal, is_literal).then(lambda a, b: a * b),
    when('/', is_literal, is_literal).where(lambda a, b: b != 0).then(
        lambda a, b: a / b
    ),

    # Algebraic identities
    when('-', '$x', '$x').then(0),                              # x - x = 0
    when('/', '$x', '$x').where(lambda x: x != 0).then(1),      # x / x = 1 (assumes a symbolic x is nonzero)
]

def simplify(expr):
    """Simplify an arithmetic expression."""
    return rewrite(expr, *[bottom_up(r) for r in simplify_rules])


# Test the simplifier
examples = [
    ('+', 'x', 0),                    # x + 0 => x
    ('*', 1, 'y'),                    # 1 * y => y
    ('*', 0, ('+', 'a', 'b')),        # 0 * (a + b) => 0
    ('+', 2, 3),                      # 2 + 3 => 5
    ('*', 4, 5),                      # 4 * 5 => 20
    ('-', 'x', 'x'),                  # x - x => 0
    ('/', 'y', 'y'),                  # y / y => 1
    ('+', ('*', 2, 3), ('*', 4, 5)),  # (2*3) + (4*5) => 26
]

print("Arithmetic Simplifier")
print("=" * 50)

for expr in examples:
    result = simplify(expr)
    print(f"{str(expr):35} => {result}")

Expected output:

Arithmetic Simplifier
==================================================
('+', 'x', 0)                       => x
('*', 1, 'y')                       => y
('*', 0, ('+', 'a', 'b'))           => 0
('+', 2, 3)                         => 5
('*', 4, 5)                         => 20
('-', 'x', 'x')                     => 0
('/', 'y', 'y')                     => 1
('+', ('*', 2, 3), ('*', 4, 5))     => 26

Extending the Simplifier

You can easily add more rules:

extended_rules = simplify_rules + [
    # Double negation: --x => x
    when('neg', ('neg', '$x')).then(lambda x: x),

    # Collecting like terms: x + x => 2 * x
    when('+', '$x', '$x').then(lambda x: ('*', 2, x)),

    # Power rules
    when('^', '$x', 0).then(1),         # x^0 = 1
    when('^', '$x', 1).then(lambda x: x),  # x^1 = x
]

Next Steps

Now that you understand the basics, explore these topics:

Additional Examples

The examples/ directory contains more sophisticated demonstrations:

  • boolean_algebra.py - Complete boolean algebra simplification with De Morgan's laws
  • pattern_matching.py - Exhaustive demonstration of all pattern features
  • calculus_differentiation.py - Symbolic differentiation
  • css_optimizer.py - Real-world CSS optimization

Cookbook Patterns

The README includes useful patterns that stay out of the core API:

# Local fixed-point at a node
def repeat(rule):
    def r(t):
        while True:
            new = rule(t)
            if new == t:
                return t
            t = new
    return r

# Top-down traversal
def top_down(rule):
    def walk(t):
        t2 = rule(t)
        if isinstance(t2, tuple) and t2:  # skip empty tuples, which have no operator
            t2 = (t2[0],) + tuple(walk(ch) for ch in t2[1:])
        return rule(t2)
    return walk

# Commutative normalization (sort arguments)
def normalize_commutative(op_name):
    return when(op_name, _, _).then(
        lambda a, b: (op_name,) + tuple(sorted((a, b), key=str))
    )
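The key=str argument in normalize_commutative matters because Python 3 refuses to order values of unrelated types, such as 1 and 'x'. The quick check below uses a hypothetical `sort_args` as a stand-in for the rule's action. Note that the resulting order is string order, not numeric order:

```python
def sort_args(op, a, b):
    """Return (op, a, b) with the two arguments sorted by their string form."""
    return (op,) + tuple(sorted((a, b), key=str))

try:
    sorted((1, 'x'))  # mixed types cannot be compared directly
except TypeError as exc:
    print("without key:", exc)

print(sort_args('+', 'x', 1))    # ('+', 1, 'x')
print(sort_args('+', 'b', 'a'))  # ('+', 'a', 'b')
print(sort_args('*', 10, 9))     # ('*', 10, 9), because '10' sorts before '9'
# Already-sorted input comes back equal, so the rewrite loop can stop.
print(sort_args('+', 1, 'x') == ('+', 1, 'x'))  # True
```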

Custom Predicates

Create domain-specific predicates for cleaner rules:

is_zero = lambda x: x == 0
is_one = lambda x: x == 1
is_variable = lambda x: isinstance(x, str) and not x.startswith('$')
is_operation = lambda name: (lambda t: isinstance(t, tuple) and len(t) > 0 and t[0] == name)
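Predicates built on == inherit Python's numeric equality: False == 0 and 0.0 == 0 are both true. The sketch below checks that behavior and uses is_operation as a predicate factory. The stricter `is_strict_zero` is a hypothetical variant, not part of the library:

```python
is_zero = lambda x: x == 0
# Predicate factory, with a length check so the empty tuple () is rejected.
is_operation = lambda name: (lambda t: isinstance(t, tuple) and len(t) > 0 and t[0] == name)

print(is_zero(0), is_zero(0.0), is_zero(False))  # True True True

def is_strict_zero(x):
    """Treat only an int or float zero as zero, not False."""
    return isinstance(x, (int, float)) and not isinstance(x, bool) and x == 0

print(is_strict_zero(0), is_strict_zero(False))  # True False

is_add = is_operation('+')
print(is_add(('+', 1, 2)), is_add(('*', 1, 2)), is_add(()))  # True False False
```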

Design Philosophy

Remember the core philosophy:

  • Algorithms (like differentiation): Write as recursive functions
  • Pattern transformations (like simplification): Write as rewrite rules

The rewriter stays simple. Your rules encode the complexity.


Common Issues

Rule Does Not Match Nested Expressions

Problem: Your rule works at the top level but not inside larger expressions.

Solution: Wrap the rule with bottom_up:

# Instead of:
rewrite(expr, rule)

# Use:
rewrite(expr, bottom_up(rule))

Named Variable Does Not Capture Value

Problem: Using '$x' in .then() returns the literal string '$x'.

Solution: Use a lambda to access the captured value:

# Wrong:
when('-', '$x', '$x').then('$x')  # Returns '$x' literally

# Correct:
when('-', '$x', '$x').then(lambda x: x)  # Returns the captured value

Guard Not Working

Problem: The .where() guard never seems to match.

Solution: Ensure the lambda parameter count matches the number of bound values:

# If pattern has two wildcards, guard needs two parameters
when('/', _, _).where(lambda a, b: b != 0).then(lambda a, b: a / b)

# If pattern has one wildcard, guard needs one parameter
when('sqrt', _).where(lambda x: isinstance(x, (int, float)) and x >= 0).then(lambda x: x ** 0.5)

Infinite Loop

Problem: The rewriter never terminates.

Solution: Ensure your rules make "progress" toward a fixed point. Avoid rules that can cycle:

# Bad: these rules cycle forever
when('a', _).then(lambda x: ('b', x))
when('b', _).then(lambda x: ('a', x))

# Good: rules reduce complexity or reach terminal forms
when('double', is_literal).then(lambda x: x * 2)
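While developing rules, a step limit turns a silent hang into an error you can inspect. The sketch below is hypothetical: `rewrite_limited` and its `max_steps` parameter are not documented Tree Rewriter features. It wraps the same restart-from-the-first-rule loop described earlier:

```python
def rewrite_limited(tree, *rules, max_steps=1000):
    """Fixed-point rewriting that raises if no fixed point is confirmed in max_steps passes."""
    for _ in range(max_steps):
        for rule in rules:
            new_tree = rule(tree)
            if new_tree != tree:
                tree = new_tree
                break  # a rule fired: restart from the first rule
        else:
            return tree  # no rule changed the tree: fixed point reached
    raise RuntimeError(f"no fixed point after {max_steps} steps: {tree!r}")

def a_to_b(tree):
    return ('b', tree[1]) if isinstance(tree, tuple) and tree[:1] == ('a',) else tree

def b_to_a(tree):
    return ('a', tree[1]) if isinstance(tree, tuple) and tree[:1] == ('b',) else tree

def double(tree):
    if (isinstance(tree, tuple) and len(tree) == 2 and tree[0] == 'double'
            and isinstance(tree[1], (int, float))):
        return tree[1] * 2
    return tree

try:
    rewrite_limited(('a', 1), a_to_b, b_to_a, max_steps=50)  # cycles forever
except RuntimeError as exc:
    print(exc)

print(rewrite_limited(('double', 3), double))  # 6
```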

Quick Reference

Import        Purpose
rewrite       Apply rules until fixed point
when          Create pattern-matching rules
_             Wildcard pattern
bottom_up     Apply rule throughout tree
commutative   Create rules for commutative ops
is_literal    Predicate for numbers/bools/None
is_type       Create type-checking predicate
first         Try rules, return first match
all           Apply rules in sequence

Pattern Syntax

Pattern            Matches
5                  Exact value 5
'+'                Exact string '+'
_                  Anything (wildcard)
'$x'               Anything, named binding
is_literal         Numbers, bools, None
lambda x: x > 0    Custom predicate
('+', _, _)        Tuple with '+' and two args

For more information, see the README and DESIGN.md.