Getting Started with Tree Rewriter¶
A hands-on tutorial for building powerful tree transformations with minimal code.
Target Audience: Python developers familiar with basic data structures who want to learn term rewriting for expression simplification, compiler transformations, or symbolic computation.
What You Will Learn:
- How to represent expressions as trees (S-expressions)
- How to write transformation rules using pattern matching
- How to apply rules until a fixed point is reached
- How to build a complete arithmetic simplifier
Prerequisites:
- Python 3.8 or later
- Basic understanding of tuples and functions
Estimated Time: 30 minutes
Table of Contents¶
- Installation
- Core Concepts
- Your First Rule
- Pattern Matching
- Traversal Strategies
- Composing Rules
- Practical Example: Arithmetic Simplifier
- Next Steps
Installation¶
Install Tree Rewriter using pip:
Alternatively, since the entire library is about 100 lines of code, you can copy the source file directly into your project.
Verify the installation:
Core Concepts¶
Tree Rewriter is built on three simple ideas:
1. Trees as S-Expressions¶
Trees are represented as nested tuples (S-expressions). This format is simple, universal, and easy to work with:
# Atomic values (leaves)
5 # A number
'x' # A variable name
# Compound expressions (internal nodes)
('+', 'x', 5) # x + 5
('*', 2, ('+', 'x', 5)) # 2 * (x + 5)
('if', 'cond', 'then', 'else') # A conditional expression
The first element of a tuple is typically an operator or tag, and the remaining elements are its arguments.
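Because a tree is just nested tuples, ordinary recursive functions can inspect it. The following is a minimal sketch using only plain Python; `count_nodes` and `leaves` are hypothetical helper names chosen for illustration and are not part of Tree Rewriter.

```python
def count_nodes(tree):
    """Count every node in a nested-tuple tree (compound nodes and leaves)."""
    if not isinstance(tree, tuple):
        return 1  # an atom is a single leaf
    # One for this node, plus its arguments (everything after the tag).
    return 1 + sum(count_nodes(arg) for arg in tree[1:])


def leaves(tree):
    """Yield atomic values left to right, skipping operator tags."""
    if not isinstance(tree, tuple):
        yield tree
        return
    for arg in tree[1:]:
        yield from leaves(arg)


expr = ('*', 2, ('+', 'x', 5))  # 2 * (x + 5)
print(count_nodes(expr))        # 5
print(list(leaves(expr)))       # [2, 'x', 5]
```

Both helpers treat the first tuple element as the tag and recurse only into the remaining elements, which mirrors the convention described above.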
2. Rules as Functions¶
A rule is simply a function that takes a tree and returns a (possibly transformed) tree:
def add_zero_rule(tree):
    """Transform x + 0 into x."""
    if tree == ('+', 'x', 0):
        return 'x'
    return tree  # No change
If the rule does not apply, it returns the tree unchanged.
3. Fixed-Point Rewriting¶
The rewrite function applies rules repeatedly until the tree stops changing:
from tree_rewriter import rewrite
def example_rule(tree):
    if tree == ('step', 1):
        return ('step', 2)
    if tree == ('step', 2):
        return ('step', 3)
    return tree
result = rewrite(('step', 1), example_rule)
print(result) # ('step', 3)
The rewriter applies rules in order. When a rule changes the tree, it restarts from the first rule. When no rule applies, the tree has reached a "fixed point" and rewriting stops.
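The restart-on-change behavior described above can be sketched in a few lines of plain Python. This is an illustration under the assumption that rules are ordinary functions from tree to tree; `rewrite_sketch` is a hypothetical name, and the library's actual implementation may differ.

```python
def rewrite_sketch(tree, *rules):
    """Apply rules in order until none of them changes the tree."""
    changed = True
    while changed:
        changed = False
        for rule in rules:
            new_tree = rule(tree)
            if new_tree != tree:
                tree = new_tree
                changed = True
                break  # a rule fired: restart from the first rule
    return tree  # fixed point: no rule changed the tree


def step_rule(tree):
    if tree == ('step', 1):
        return ('step', 2)
    if tree == ('step', 2):
        return ('step', 3)
    return tree


print(rewrite_sketch(('step', 1), step_rule))  # ('step', 3)
```

The loop stops only when a full pass over the rules leaves the tree unchanged, which is why rules that undo each other never reach a fixed point.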
Your First Rule¶
Let us create a simple rule that eliminates addition by zero. Instead of writing raw functions, Tree Rewriter provides a fluent API using when and then:
from tree_rewriter import rewrite, when, _
# Create a rule: when we see ('+', 0, something), return that something
rule = when('+', 0, _).then(lambda x: x)
# Test it
expr = ('+', 0, 'x')
result = rewrite(expr, rule)
print(result) # 'x'
Breaking this down:
- when('+', 0, _) defines a pattern to match
- _ is a wildcard that matches anything
- .then(lambda x: x) specifies what to produce when the pattern matches
- The wildcard binds to x in the lambda
You can also return a constant value:
# Multiplication by zero always yields zero
rule = when('*', 0, _).then(0)
result = rewrite(('*', 0, 'anything'), rule)
print(result) # 0
Pattern Matching¶
Tree Rewriter provides several pattern matching features that make rules expressive and readable.
Wildcards¶
The underscore _ matches any single value:
from tree_rewriter import when, _
# Match any binary addition
rule = when('+', _, _).then(lambda a, b: f"Adding {a} and {b}")
Each wildcard binds to the next parameter in the lambda, in order from left to right.
Named Variables¶
Use $name syntax to create named bindings. Named variables must match consistently:
# x - x = 0 (the same value on both sides)
rule = when('-', '$x', '$x').then(0)
# This matches because both arguments are 'y'
rewrite(('-', 'y', 'y'), rule) # Returns 0
# This does NOT match because arguments differ
rewrite(('-', 'a', 'b'), rule) # Returns ('-', 'a', 'b') unchanged
Named variables are passed to the lambda in the order they first appear:
# Swap two values
rule = when('swap', '$a', '$b').then(lambda a, b: ('swap', b, a))
rewrite(('swap', 1, 2), rule) # Returns ('swap', 2, 1)
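To make the binding rules concrete, here is a minimal sketch of a matcher that supports wildcards and consistent named variables. It assumes nothing about how Tree Rewriter implements matching; `WILDCARD` and `match_sketch` are hypothetical names used only for this illustration.

```python
WILDCARD = object()  # hypothetical stand-in for the library's _


def match_sketch(pattern, tree):
    """Return bound values in order of first appearance, or None if no match."""
    values = []  # values handed to the lambda, left to right
    named = {}   # '$name' -> value already bound

    def walk(p, t):
        if p is WILDCARD:
            values.append(t)  # every wildcard binds a fresh value
            return True
        if isinstance(p, str) and p.startswith('$'):
            if p in named:
                return named[p] == t  # a repeated name must agree
            named[p] = t
            values.append(t)
            return True
        if isinstance(p, tuple):
            return (isinstance(t, tuple) and len(p) == len(t)
                    and all(walk(pi, ti) for pi, ti in zip(p, t)))
        return p == t  # literal: exact equality

    return values if walk(pattern, tree) else None


print(match_sketch(('-', '$x', '$x'), ('-', 'y', 'y')))      # ['y']
print(match_sketch(('-', '$x', '$x'), ('-', 'a', 'b')))      # None
print(match_sketch(('swap', '$a', '$b'), ('swap', 1, 2)))    # [1, 2]
print(match_sketch(('+', WILDCARD, WILDCARD), ('+', 'x', 5)))  # ['x', 5]
```

A repeated `$x` contributes only one value to the result, which is consistent with the `x - x` rule above taking no more than one argument.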
Predicates¶
Instead of matching exact values, you can use predicate functions that return True or False:
from tree_rewriter import rewrite, when, is_literal, is_type
# is_literal matches numbers, booleans, None, and complex numbers
rule = when('+', is_literal, is_literal).then(lambda a, b: a + b)
rewrite(('+', 3, 4), rule) # Returns 7
rewrite(('+', 'x', 4), rule) # Returns ('+', 'x', 4) - no match
# is_type creates custom type predicates
is_int = is_type(int)
is_number = is_type(int, float)
You can also write inline predicates:
# Match only positive numbers
is_positive = lambda x: isinstance(x, (int, float)) and x > 0
rule = when('double', is_positive).then(lambda x: x * 2)
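A predicate is just a callable that returns True or False, so a type-checking factory is easy to sketch. The version below assumes a predicate built on `isinstance`; `is_type_sketch` and `is_strict_int` are hypothetical names, and the library's own `is_type` may behave differently, particularly for booleans.

```python
def is_type_sketch(*types):
    """Build a predicate that accepts values of any of the given types."""
    def predicate(value):
        return isinstance(value, types)
    return predicate


def is_strict_int(value):
    """Accept ints but reject bools, which isinstance treats as ints."""
    return isinstance(value, int) and not isinstance(value, bool)


is_number = is_type_sketch(int, float)
print(is_number(3.5))       # True
print(is_number('x'))       # False
print(is_number(True))      # True, because bool is a subclass of int
print(is_strict_int(True))  # False
```

If a rule should fold numbers but leave booleans alone, a stricter predicate along the lines of `is_strict_int` is one way to express that.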
Guards (Where Clauses)¶
Add extra conditions with .where():
# Division, but only when divisor is not zero
rule = when('/', _, _).where(lambda a, b: b != 0).then(lambda a, b: a / b)
rewrite(('/', 10, 2), rule) # Returns 5.0
rewrite(('/', 10, 0), rule) # Returns ('/', 10, 0) - guard failed
Guards are checked after the pattern matches. The lambda receives all bound values in order.
Nested Patterns¶
Patterns can match nested structures:
# Match double negation: not(not(x)) => x
rule = when('not', ('not', '$x')).then(lambda x: x)
rewrite(('not', ('not', 'a')), rule) # Returns 'a'
# Match a specific nested shape
rule = when('first', ('pair', '$a', '$b')).then(lambda a, b: a)
rewrite(('first', ('pair', 1, 2)), rule) # Returns 1
Traversal Strategies¶
By default, rules only apply to the root of the tree. To transform nested subexpressions, use traversal strategies.
Bottom-Up Traversal¶
The bottom_up function wraps a rule to apply it throughout a tree, starting from the leaves and working up:
from tree_rewriter import rewrite, when, _, bottom_up
# This rule only matches at root level
rule = when('+', 0, _).then(lambda x: x)
# Without bottom_up: inner expression not transformed
expr = ('*', ('+', 0, 'x'), 2)
rewrite(expr, rule) # Still ('*', ('+', 0, 'x'), 2)
# With bottom_up: transforms the nested ('+', 0, 'x') first
rewrite(expr, bottom_up(rule)) # Returns ('*', 'x', 2)
Pro Tip: When building a simplifier, always wrap your rules with bottom_up:
rules = [
    when('+', 0, _).then(lambda x: x),
    when('*', 1, _).then(lambda x: x),
]
# Apply all rules bottom-up
result = rewrite(expr, *[bottom_up(r) for r in rules])
Why Bottom-Up?¶
Consider simplifying ('+', ('+', 0, 'x'), 0):
- Bottom-up transforms the inner ('+', 0, 'x') to 'x' first, giving ('+', 'x', 0). A rule for x + 0, such as when('+', _, 0), then reduces that to 'x'.
- Top-down would try to match the outer expression first, but ('+', ('+', 0, 'x'), 0) does not match ('+', 0, _) because the first argument is not 0.
Bottom-up ensures simpler subexpressions are created before attempting to match outer patterns.
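The traversal order can be illustrated with a short sketch that rebuilds each node from its already-transformed children before applying the rule to the node itself. This is an assumption about how such a wrapper could work, not the library's code; `bottom_up_sketch` and `add_zero` are hypothetical names.

```python
def bottom_up_sketch(rule):
    """Wrap a rule so it is applied to children first, then to the node."""
    def walk(tree):
        if isinstance(tree, tuple) and tree:
            # Keep the tag, transform every argument first.
            tree = (tree[0],) + tuple(walk(child) for child in tree[1:])
        return rule(tree)  # then try the rule on the rebuilt node
    return walk


def add_zero(tree):
    """0 + x => x and x + 0 => x, written as a plain function."""
    if isinstance(tree, tuple) and len(tree) == 3 and tree[0] == '+':
        if tree[1] == 0:
            return tree[2]
        if tree[2] == 0:
            return tree[1]
    return tree


simplify_once = bottom_up_sketch(add_zero)
print(simplify_once(('+', ('+', 0, 'x'), 0)))  # x
print(simplify_once(('*', ('+', 0, 'x'), 2)))  # ('*', 'x', 2)
```

In the first call the inner addition collapses to 'x' before the outer node is examined, so a single pass is enough.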
Composing Rules¶
Tree Rewriter provides several ways to combine rules.
Multiple Rules with rewrite¶
Pass multiple rules to rewrite. Rules are tried in order, and when one succeeds, the process restarts from the first rule:
from tree_rewriter import rewrite, when, _, bottom_up
rules = [
    when('+', 0, _).then(lambda x: x),  # 0 + x = x
    when('+', _, 0).then(lambda x: x),  # x + 0 = x
    when('*', 0, _).then(0),            # 0 * x = 0
]
# The zero is on the left so that the 0 * x rule can match
expr = ('*', 0, ('+', 'y', 0))
result = rewrite(expr, *[bottom_up(r) for r in rules])
print(result) # 0
Commutative Helper¶
Many operators are commutative (order does not matter). The commutative helper generates both orderings:
from tree_rewriter import commutative
# Instead of writing two rules:
# when('+', 0, _).then(lambda x: x)
# when('+', _, 0).then(lambda x: x)
# Write once:
rules = commutative('+', 0, lambda x: x)
# Returns a list of two rules
Use the unpacking operator (*) to include them in your rule list:
all_rules = [
    *commutative('+', 0, lambda x: x),  # 0 + x = x + 0 = x
    *commutative('*', 1, lambda x: x),  # 1 * x = x * 1 = x
    *commutative('*', 0, 0),            # 0 * x = x * 0 = 0
]
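The idea of generating both argument orders can be sketched with plain functions instead of the fluent API. The code below is only an illustration of the technique; `commutative_sketch` is a hypothetical name and makes no claim about how the library's `commutative` is written.

```python
def commutative_sketch(op, value, result):
    """Return two rules: one for (op, value, x) and one for (op, x, value)."""
    def produce(x):
        # result may be a function of the other argument or a constant
        return result(x) if callable(result) else result

    def is_binary(tree):
        return isinstance(tree, tuple) and len(tree) == 3 and tree[0] == op

    def left(tree):
        if is_binary(tree) and tree[1] == value:
            return produce(tree[2])
        return tree

    def right(tree):
        if is_binary(tree) and tree[2] == value:
            return produce(tree[1])
        return tree

    return [left, right]


add_left, add_right = commutative_sketch('+', 0, lambda x: x)
print(add_left(('+', 0, 'y')))   # y
print(add_right(('+', 'y', 0)))  # y

mul_left, mul_right = commutative_sketch('*', 0, 0)
print(mul_right(('*', 'a', 0)))  # 0
```

Returning a list keeps the helper compatible with unpacking into a larger rule list.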
first - Try Rules in Order¶
The first combinator tries rules in order and returns the result of the first one that changes the tree:
from tree_rewriter import rewrite, first, when, _
combined = first(
    when('a', _).then(lambda x: ('matched-a', x)),
    when('b', _).then(lambda x: ('matched-b', x)),
)
rewrite(('a', 1), combined) # ('matched-a', 1)
rewrite(('b', 2), combined) # ('matched-b', 2)
rewrite(('c', 3), combined) # ('c', 3) - no match
all - Apply Rules in Sequence¶
The all combinator applies every rule in sequence, passing the result of each to the next:
from tree_rewriter import rewrite, all, when, _
step1 = when('val', _).then(lambda x: ('step1', x))
step2 = when('step1', _).then(lambda x: ('step2', x))
pipeline = all(step1, step2)
rewrite(('val', 'data'), pipeline) # ('step2', 'data')
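The difference between the two combinators is easiest to see side by side. The sketch below treats rules as plain functions; `first_sketch`, `all_sketch`, and `tag_rule` are hypothetical names for illustration, and the library's combinators may differ in detail.

```python
def first_sketch(*rules):
    """Return the result of the first rule that changes the tree."""
    def combined(tree):
        for rule in rules:
            result = rule(tree)
            if result != tree:
                return result
        return tree  # no rule applied
    return combined


def all_sketch(*rules):
    """Apply every rule once, feeding each result into the next rule."""
    def combined(tree):
        for rule in rules:
            tree = rule(tree)
        return tree
    return combined


def tag_rule(old, new):
    """Hypothetical helper: rewrite (old, x) to (new, x)."""
    def rule(tree):
        if isinstance(tree, tuple) and len(tree) == 2 and tree[0] == old:
            return (new, tree[1])
        return tree
    return rule


pipeline = all_sketch(tag_rule('val', 'step1'), tag_rule('step1', 'step2'))
print(pipeline(('val', 'data')))  # ('step2', 'data')

choice = first_sketch(tag_rule('a', 'matched-a'), tag_rule('b', 'matched-b'))
print(choice(('b', 2)))  # ('matched-b', 2)
print(choice(('c', 3)))  # ('c', 3)
```

Note that a combinator named `all` shadows Python's built-in `all` in any module that imports it, so importing it under an alias may be worth considering.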
Practical Example: Arithmetic Simplifier¶
Let us build a complete arithmetic expression simplifier that demonstrates all the concepts.
from tree_rewriter import rewrite, when, _, bottom_up, commutative, is_literal
# Define simplification rules
simplify_rules = [
    # Identity elements: x + 0 = x, x * 1 = x
    *commutative('+', 0, lambda x: x),
    *commutative('*', 1, lambda x: x),
    # Absorbing element: x * 0 = 0
    *commutative('*', 0, 0),
    # Constant folding: compute operations on known values
    when('+', is_literal, is_literal).then(lambda a, b: a + b),
    when('-', is_literal, is_literal).then(lambda a, b: a - b),
    when('*', is_literal, is_literal).then(lambda a, b: a * b),
    when('/', is_literal, is_literal).where(lambda a, b: b != 0).then(
        lambda a, b: a / b
    ),
    # Algebraic identities
    when('-', '$x', '$x').then(0),                          # x - x = 0
    when('/', '$x', '$x').where(lambda x: x != 0).then(1),  # x / x = 1
]

def simplify(expr):
    """Simplify an arithmetic expression."""
    return rewrite(expr, *[bottom_up(r) for r in simplify_rules])
# Test the simplifier
examples = [
    ('+', 'x', 0),                    # x + 0 => x
    ('*', 1, 'y'),                    # 1 * y => y
    ('*', 0, ('+', 'a', 'b')),        # 0 * (a + b) => 0
    ('+', 2, 3),                      # 2 + 3 => 5
    ('*', 4, 5),                      # 4 * 5 => 20
    ('-', 'x', 'x'),                  # x - x => 0
    ('/', 'y', 'y'),                  # y / y => 1
    ('+', ('*', 2, 3), ('*', 4, 5)),  # (2*3) + (4*5) => 26
]
print("Arithmetic Simplifier")
print("=" * 50)
for expr in examples:
    result = simplify(expr)
    print(f"{str(expr):35} => {result}")
Expected output:
Arithmetic Simplifier
==================================================
('+', 'x', 0)                       => x
('*', 1, 'y')                       => y
('*', 0, ('+', 'a', 'b'))           => 0
('+', 2, 3)                         => 5
('*', 4, 5)                         => 20
('-', 'x', 'x')                     => 0
('/', 'y', 'y')                     => 1
('+', ('*', 2, 3), ('*', 4, 5))     => 26
Extending the Simplifier¶
You can easily add more rules:
extended_rules = simplify_rules + [
    # Double negation: --x => x
    when('neg', ('neg', '$x')).then(lambda x: x),
    # Collecting like terms: x + x => 2 * x
    when('+', '$x', '$x').then(lambda x: ('*', 2, x)),
    # Power rules
    when('^', '$x', 0).then(1),            # x^0 = 1
    when('^', '$x', 1).then(lambda x: x),  # x^1 = x
]
Next Steps¶
Now that you understand the basics, explore these topics:
Additional Examples¶
The examples/ directory contains more sophisticated demonstrations:
- boolean_algebra.py - Complete boolean algebra simplification with De Morgan's laws
- pattern_matching.py - Exhaustive demonstration of all pattern features
- calculus_differentiation.py - Symbolic differentiation
- css_optimizer.py - Real-world CSS optimization
Cookbook Patterns¶
The README includes useful patterns that stay out of the core API:
from tree_rewriter import when, _

# Local fixed-point at a node
def repeat(rule):
    def r(t):
        while True:
            new = rule(t)
            if new == t:
                return t
            t = new
    return r

# Top-down traversal
def top_down(rule):
    def walk(t):
        t2 = rule(t)
        if isinstance(t2, tuple):
            t2 = (t2[0],) + tuple(walk(ch) for ch in t2[1:])
        return rule(t2)
    return walk

# Commutative normalization (sort arguments)
def normalize_commutative(op_name):
    return when(op_name, _, _).then(
        lambda a, b: (op_name,) + tuple(sorted((a, b), key=str))
    )
Custom Predicates¶
Create domain-specific predicates for cleaner rules:
is_zero = lambda x: x == 0
is_one = lambda x: x == 1
is_variable = lambda x: isinstance(x, str) and not x.startswith('$')
is_operation = lambda name: (lambda t: isinstance(t, tuple) and t[0] == name)
Design Philosophy¶
Remember the core philosophy:
- Algorithms (like differentiation): Write as recursive functions
- Pattern transformations (like simplification): Write as rewrite rules
The rewriter stays simple. Your rules encode the complexity.
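As an illustration of the first bullet, symbolic differentiation fits naturally into a recursive function over nested tuples. The sketch below is a hypothetical example written for this tutorial, not the code in examples/calculus_differentiation.py; it handles only binary '+' and '*'.

```python
def diff(expr, var):
    """Differentiate a nested-tuple expression with respect to var."""
    if isinstance(expr, (int, float)):
        return 0  # derivative of a constant
    if isinstance(expr, str):
        return 1 if expr == var else 0  # dx/dx = 1, other variables are constants
    op, left, right = expr
    if op == '+':  # sum rule
        return ('+', diff(left, var), diff(right, var))
    if op == '*':  # product rule
        return ('+',
                ('*', diff(left, var), right),
                ('*', left, diff(right, var)))
    raise ValueError(f"unsupported operator: {op!r}")


print(diff(('*', 'x', 'x'), 'x'))
# ('+', ('*', 1, 'x'), ('*', 'x', 1))
```

The raw result is correct but untidy, which is where rewrite rules such as the identity rules for multiplication by 1 come in: the algorithm produces the tree and the rules clean it up.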
Common Issues¶
Rule Does Not Match Nested Expressions¶
Problem: Your rule works at the top level but not inside larger expressions.
Solution: Wrap the rule with bottom_up, for example rewrite(expr, bottom_up(rule)).
Named Variable Does Not Capture Value¶
Problem: Using '$x' in .then() returns the literal string '$x'.
Solution: Use a lambda to access the captured value:
# Wrong:
when('-', '$x', '$x').then('$x') # Returns '$x' literally
# Correct:
when('-', '$x', '$x').then(lambda x: x) # Returns the captured value
Guard Not Working¶
Problem: The .where() guard never seems to match.
Solution: Ensure the lambda parameter count matches the number of bound values:
# If pattern has two wildcards, guard needs two parameters
when('/', _, _).where(lambda a, b: b != 0).then(lambda a, b: a / b)
# If pattern has one wildcard, guard needs one parameter
when('sqrt', _).where(lambda x: x >= 0).then(lambda x: x ** 0.5)
Infinite Loop¶
Problem: The rewriter never terminates.
Solution: Ensure your rules make "progress" toward a fixed point. Avoid rules that can cycle:
# Bad: these rules cycle forever
when('a', _).then(lambda x: ('b', x))
when('b', _).then(lambda x: ('a', x))
# Good: rules reduce complexity or reach terminal forms
when('double', is_literal).then(lambda x: x * 2)
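While debugging a rule set that might cycle, it can help to run it through a driver with a step limit so that a cycle surfaces as an error instead of a hang. The following is a minimal sketch in plain Python; `rewrite_bounded` and its `max_steps` parameter are hypothetical and are not part of the Tree Rewriter API.

```python
def rewrite_bounded(tree, *rules, max_steps=1000):
    """Fixed-point loop that raises instead of spinning forever."""
    for _step in range(max_steps):
        for rule in rules:
            new_tree = rule(tree)
            if new_tree != tree:
                tree = new_tree
                break  # a rule fired: start the next step from the first rule
        else:
            return tree  # no rule changed the tree: fixed point reached
    raise RuntimeError(f"no fixed point after {max_steps} steps")


def a_to_b(tree):
    if isinstance(tree, tuple) and tree[:1] == ('a',):
        return ('b',) + tree[1:]
    return tree


def b_to_a(tree):
    if isinstance(tree, tuple) and tree[:1] == ('b',):
        return ('a',) + tree[1:]
    return tree


try:
    rewrite_bounded(('a', 1), a_to_b, b_to_a, max_steps=50)
except RuntimeError as err:
    print(err)  # no fixed point after 50 steps
```

With only `a_to_b` in the rule list the same call terminates normally, which makes it easy to narrow down which pair of rules undo each other.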
Quick Reference¶
| Import | Purpose |
|---|---|
| rewrite | Apply rules until fixed point |
| when | Create pattern-matching rules |
| _ | Wildcard pattern |
| bottom_up | Apply rule throughout tree |
| commutative | Create rules for commutative ops |
| is_literal | Predicate for numbers/bools/None |
| is_type | Create type-checking predicate |
| first | Try rules, return first match |
| all | Apply rules in sequence |
Pattern Syntax¶
| Pattern | Matches |
|---|---|
| 5 | Exact value 5 |
| '+' | Exact string '+' |
| _ | Anything (wildcard) |
| '$x' | Anything, named binding |
| is_literal | Numbers, bools, None |
| lambda x: x > 0 | Custom predicate |
| ('+', _, _) | Tuple with '+' and two args |