Introduction¶
The ONNX Rewriter tool lets you replace patterns in an ONNX graph with other patterns, based on conditional rewrite rules that you provide.
Usage¶
There are three main components needed when rewriting patterns in the graph:
target_pattern: Original pattern to match against. This pattern is written as a function using ONNXScript-like operators.
replacement_pattern: Pattern to replace the original pattern with. This pattern is also written as a function using ONNXScript-like operators.
match_condition (optional): Pattern rewrite will occur only if the match condition is satisfied.
Pattern Options¶
When defining patterns, you can use several special options to control how patterns match and what they produce:
_allow_other_attributes: Controls whether the pattern allows additional attributes not specified in the pattern (default: True)
_allow_other_inputs: Controls whether the pattern allows additional inputs beyond those specified (default: False)
_domain: Specifies the operator domain for matching or creating operations
_outputs: Specifies the number and optionally names of outputs from an operation
These options are documented in detail in the following sections.
A Simple Example¶
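A simple example demonstrating the usage of this functionality uses the GELU activation function.
The GELU activation function can be computed using the Gauss error function with the following formula:
$$\mathrm{GELU}(x) = 0.5 \cdot x \cdot \left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$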
We will show how we can find a subgraph matching this computation and replace it by a call to the function.
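As a quick sanity check, the erf-based expression can be evaluated in plain Python with only the standard library. This is a minimal sketch, independent of the rewriter; the function names gelu_erf and gelu_pattern_order are illustrative and not part of onnxscript. The second function uses the same operand order as the target pattern defined later in this section.

```python
import math


def gelu_erf(x: float) -> float:
    """GELU computed with the Gauss error function."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_pattern_order(x: float) -> float:
    """Same value, written in the operand order used by the target pattern."""
    return 0.5 * (x * (math.erf(x / math.sqrt(2)) + 1.0))


for v in (-1.0, 0.0, 1.0):
    print(v, gelu_erf(v), gelu_pattern_order(v))
```

Both functions compute the same value; they differ only in the order of operands, which matters for structural pattern matching even though it does not matter numerically.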
First, include the imports that the rewriter examples rely on.
import math
import onnxscript
from onnxscript.rewriter import pattern
from onnxscript import ir
Then create a target pattern that needs to be replaced using onnxscript operators.
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
After this, create a replacement pattern that consists of the GELU onnxscript operator.
def gelu(op, x: ir.Value):
    return op.Gelu(x, _domain="com.microsoft")
Note
The inputs to the replacement pattern are of type ir.Value. For detailed usage of ir.Value, refer to the ir.Value class.
For this example, we do not require a match_condition, so that option is skipped for now. Then the rewrite rule is created using the RewriteRule class.
rule = pattern.RewriteRule(
    erf_gelu_pattern,  # Target Pattern
    gelu,  # Replacement Pattern
)
It is more convenient to organize more complex rewrite-rules as a class. The above rule can be alternatively expressed as below.
class ErfGeluFusion(pattern.RewriteRuleClassBase):
    def pattern(self, op, x):
        return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5

    def rewrite(self, op, x):
        return op.Gelu(x, _domain="com.microsoft")
The corresponding rewrite-rule can be obtained as below:
erf_gelu_rule_from_class = ErfGeluFusion.rule()
Now that the rewrite rule has been created, the next step is to apply these pattern-based rewrite rules. The rewriter.rewrite(model, pattern_rewrite_rules) call applies the specified rewrite rules to the given model.
model: The original model on which the pattern rewrite rules are to be applied. This is of type ir.Model or onnx.ModelProto. If the model is an ir.Model, the rewriter applies the changes in-place, modifying the input model. If it is a ModelProto, the rewriter returns a new ModelProto representing the transformed model.
pattern_rewrite_rules: This parameter is either a Sequence[PatternRewriteRule] or a RewriteRuleSet.
Note
For steps on how to create and use Rule-sets, refer to the example in the section Creating a rule-set with different patterns.
The snippet below demonstrates how to use the rewriter.rewrite call for the rewrite rule created above:
def apply_rewrite(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[rule],
    )
    return model_with_rewrite_applied
The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended.
Specifying attributes in the pattern¶
This section demonstrates the use of attribute values in pattern-based rewriting.
First, write a target pattern and replacement pattern in a similar way to the previous examples.
The example pattern below will match successfully only against Dropout nodes with the attribute value training_mode set to False.
The _allow_other_attributes option allows the pattern to match nodes that have additional attributes not specified in the pattern. If it is set to False, then the node must have only the specified attribute values, and no other attributes, for a successful match. The default value for this option is True.
def add_pattern(op, input):
    return op.Dropout(input, training_mode=False, _allow_other_attributes=True)

def add_replacement(op, input, **_):
    return op.Identity(input)

def apply_rewrite(model):
    # Create rewrite rules
    add_rule = pattern.RewriteRule(
        add_pattern,  # target pattern
        add_replacement,  # replacement pattern
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([add_rule])
    # Apply rewrite (no match_condition is used in this example)
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite
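The effect of _allow_other_attributes can be modelled in a few lines of plain Python. This is only a sketch of the semantics described in this section, not the rewriter's implementation: attributes_match is a hypothetical helper, and attributes are modelled as plain dictionaries rather than IR objects.

```python
def attributes_match(pattern_attrs, node_attrs, allow_other_attributes=True):
    """Hypothetical model of the attribute check, not the library's code."""
    # Every attribute named in the pattern must be present with the same value.
    for name, expected in pattern_attrs.items():
        if name not in node_attrs or node_attrs[name] != expected:
            return False
    # With allow_other_attributes=False the node may not carry any extras.
    if not allow_other_attributes and set(node_attrs) - set(pattern_attrs):
        return False
    return True


pattern_attrs = {"training_mode": False}
node_attrs = {"training_mode": False, "seed": 42}

print(attributes_match(pattern_attrs, node_attrs))
print(attributes_match(pattern_attrs, node_attrs, allow_other_attributes=False))
```

With the default setting the extra seed attribute is ignored; with allow_other_attributes=False the same node is rejected.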
Specifying variable inputs in the pattern¶
This section demonstrates the use of the _allow_other_inputs option in pattern-based rewriting.
The _allow_other_inputs option allows the pattern to match nodes that have additional inputs beyond those specified in the pattern. If it is set to False (the default), then the node must have exactly the specified inputs for a successful match. If set to True, the pattern will match nodes that have the specified inputs plus any number of additional inputs.
This is particularly useful when matching operations like Conv that can have optional inputs (such as bias), or when creating generic patterns that should work with various input configurations.
def conv_pattern(op, input, weight):
    # Pattern to match Conv operations, allowing additional inputs like bias
    # _allow_other_inputs=True allows the pattern to match Conv with bias (3 inputs)
    # even though we only specify 2 inputs in the pattern
    return op.Conv(input, weight, _allow_other_inputs=True)

def conv_replacement(op, input, weight, **_):
    # Replace with a custom operation in a different domain
    return op.OptimizedConv(input, weight, _domain="custom.domain")

def apply_rewrite(model):
    # Create rewrite rules
    conv_rule = pattern.RewriteRule(
        conv_pattern,  # target pattern
        conv_replacement,  # replacement pattern
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([conv_rule])
    # Apply rewrite
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite
In this example, the pattern matches Conv operations with any number of inputs. A Conv operation might have 2 inputs (input and weight) or 3 inputs (input, weight, and bias). By setting _allow_other_inputs=True, our pattern will match both cases even though we only specify 2 inputs in the pattern definition.
Specifying domains in the pattern¶
This section demonstrates the use of the _domain option in pattern-based rewriting.
The _domain option allows you to specify which operator domain the pattern should match against, and also allows you to create replacement operations in specific domains.
ONNX operators can belong to different domains:
The default ONNX domain (empty string or “ai.onnx”)
Custom domains like “com.microsoft” for Microsoft-specific operations
User-defined domains for custom operations
Matching operations from a specific domain¶
def custom_relu_pattern(op, input):
    # Pattern to match Relu operations from a specific domain
    # _domain="custom.domain" specifies we only want to match operations from this domain
    return op.Relu(input, _domain="custom.domain")
In this pattern, _domain="custom.domain" ensures that only Relu operations from the “custom.domain” domain will be matched, not standard ONNX Relu operations.
Creating replacement operations in a specific domain¶
def microsoft_relu_replacement(op, input, **_):
    # Replace with operation in Microsoft's domain
    return op.OptimizedRelu(input, _domain="com.microsoft")
Here, the replacement operation is created in the “com.microsoft” domain, which might provide optimized implementations of standard operations.
Complete rewrite example¶
def apply_rewrite(model):
    # Create rewrite rules
    relu_rule = pattern.RewriteRule(
        custom_relu_pattern,  # target pattern - matches custom domain operations
        microsoft_relu_replacement,  # replacement pattern - creates the op in the com.microsoft domain
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([relu_rule])
    # Apply rewrite
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite
This example shows how domain-specific pattern matching can be used to migrate operations between different operator domains, such as replacing custom domain operations with standard ONNX operations or vice versa.
Specifying outputs in the pattern¶
This section demonstrates the use of the _outputs option in pattern-based rewriting.
The _outputs option allows you to specify the number of outputs an operation produces and optionally assign names to those outputs for easier reference in replacement patterns.
The _outputs option can be specified in two ways:
As an integer: _outputs=2 specifies that the operation produces 2 unnamed outputs
As a list of strings/None: _outputs=["first", "second"] specifies 2 named outputs
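The two accepted forms can be thought of as two spellings of the same thing: a list of output names in which an unnamed output is None. The sketch below illustrates that reading in plain Python. normalize_outputs is a hypothetical helper written for this illustration and is not a function provided by onnxscript.

```python
def normalize_outputs(spec):
    """Hypothetical helper: expand an `_outputs` spec into a list of names.

    An int n becomes n unnamed outputs (None); a list is validated and copied.
    """
    if isinstance(spec, int):
        return [None] * spec
    names = list(spec)
    if not all(name is None or isinstance(name, str) for name in names):
        raise TypeError("output names must be strings or None")
    return names


print(normalize_outputs(2))
print(normalize_outputs(["first", "second"]))
```

Under this reading, _outputs=2 and _outputs=["first", "second"] both describe an operation with two outputs, which is why the two forms can be compared by length when checking that a pattern and its replacement agree.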
Matching operations with multiple outputs¶
def split_pattern(op, input):
    # Pattern to match Split operations with 2 outputs
    # num_outputs=2 corresponds to the attribute of the ONNX Split op
    # _outputs=2 is an option controlling the pattern constructor
    return op.Split(input, num_outputs=2, axis=0, _outputs=2)
This pattern matches Split operations that produce exactly 2 outputs. The _outputs=2 specification ensures the pattern only matches operations with this specific output count.
Creating replacement operations with named outputs¶
def custom_split_replacement(op, input, **_):
    # Replace with a custom split operation using named outputs
    # _outputs=["first_half", "second_half"] assigns names to the outputs
    # IMPORTANT: The number of outputs must match the pattern (2 outputs)
    return op.CustomSplit(
        input, _domain="custom.domain", _outputs=["first_half", "second_half"]
    )
In the replacement, _outputs=["first_half", "second_half"] creates two outputs with descriptive names. This can make the replacement pattern more readable and maintainable.
Important: The number of outputs in the replacement pattern must match the number of outputs in the target pattern. Since the pattern specifies _outputs=2, the replacement must also produce exactly 2 outputs.
Complete rewrite example¶
def apply_rewrite(model):
    # Create rewrite rules
    split_rule = pattern.RewriteRule(
        split_pattern,  # target pattern - matches Split with 2 outputs
        custom_split_replacement,  # replacement pattern - uses named outputs
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([split_rule])
    # Apply rewrite
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite
The _outputs option is particularly important when:
Working with operations that have variable numbers of outputs (like Split)
Creating custom operations that need specific output configurations
Ensuring pattern matching precision by specifying exact output counts
Improving code readability by naming outputs in replacement patterns
Using the match_condition parameter for pattern-matching¶
This section describes how to use the match_condition parameter. The match_condition parameter supplies a function that checks additional constraints on a subgraph that has already matched the target pattern; the rewrite is applied only if the function returns True.
Let us consider a model which consists of the following pattern.
Based on the ONNX MatMul spec, ONNX MatMul behaves like numpy.matmul and also follows numpy broadcasting. So in this particular pattern, if MatMul broadcasting is enough, then we don’t need the reshapes. To validate this, we need to check the following:
Input shapes check: input_a and input_b should be broadcastable
Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
If the above are true, then we don’t need the reshapes and we can eliminate them using a pattern based rewrite.
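The two checks can be sketched on plain shape lists, using only the standard library. This is an illustrative model of numpy.matmul shape rules with static integer dimensions; matmul_output_shape and reshapes_are_redundant are hypothetical names, and the condition function shown later in this section works on IR values instead.

```python
from itertools import zip_longest


def matmul_output_shape(a, b):
    """Output shape of numpy.matmul for input shapes a and b, or None if incompatible."""
    a, b = list(a), list(b)
    if not a or not b:
        return None  # matmul does not accept scalars
    a_was_1d, b_was_1d = len(a) == 1, len(b) == 1
    if a_was_1d:
        a = [1, *a]  # promote a vector to a row matrix
    if b_was_1d:
        b = [*b, 1]  # promote a vector to a column matrix
    if a[-1] != b[-2]:
        return None  # contracted dimensions must agree
    batch = []
    # Broadcast the leading (batch) dimensions from right to left.
    for da, db in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if da != db and 1 not in (da, db):
            return None
        batch.append(db if da == 1 else da)
    batch.reverse()
    core = []
    if not a_was_1d:
        core.append(a[-2])  # the promoted row dimension is dropped again
    if not b_was_1d:
        core.append(b[-1])  # the promoted column dimension is dropped again
    return batch + core


def reshapes_are_redundant(a, b, shape_c):
    """Both checks: inputs are MatMul-compatible and shape_c equals the result shape."""
    out = matmul_output_shape(a, b)
    return out is not None and out == list(shape_c)


print(matmul_output_shape([2, 3, 4], [4, 5]))
print(reshapes_are_redundant([2, 3, 4], [4, 5], [2, 3, 5]))
```

If reshapes_are_redundant returns True for the shapes in the graph, the plain MatMul already produces the requested output shape and the surrounding Reshape nodes add nothing.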
First, write a target pattern and replacement pattern in a similar way to the first example.
def two_reshapes_matmul_reshape_pattern(op, input_a, input_b, shape_a, shape_b, shape_c):
    reshape_a = op.Reshape(input_a, shape_a)
    reshape_b = op.Reshape(input_b, shape_b)
    matmul = op.MatMul(reshape_a, reshape_b)
    return op.Reshape(matmul, shape_c)

def matmul_pattern(op, input_a: ir.Value, input_b: ir.Value, **_):
    return op.MatMul(input_a, input_b)
Note
The target pattern in this case has 5 inputs: input_a, input_b, shape_a, shape_b, shape_c. However, the replacement pattern only utilizes input_a and input_b. To avoid referencing all the unused parameters in the replacement pattern signature, pass only input_a and input_b and use **_ to represent all the unused parameters.
Similarly for writing the condition checking function, we require only input_a, input_b and shape_c. Use **_ to represent all the unused parameters in the condition matching function signature.
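The role of **_ is ordinary Python keyword-argument handling and can be seen without the rewriter. The sketch below assumes, as the note implies, that the matched pattern variables are handed over by name; the bindings dictionary and both functions are hypothetical stand-ins that use strings in place of IR values.

```python
def replacement(op, input_a, input_b, **_):
    # Only the two bindings this function needs are named; the rest land in **_.
    return f"{op}.MatMul({input_a}, {input_b})"


def strict_replacement(op, input_a, input_b):
    # No **_ here, so any extra keyword argument is an error.
    return f"{op}.MatMul({input_a}, {input_b})"


# Hypothetical bindings, standing in for what a successful match would supply.
bindings = {
    "input_a": "A",
    "input_b": "B",
    "shape_a": "SA",
    "shape_b": "SB",
    "shape_c": "SC",
}

result = replacement("op", **bindings)
print(result)

try:
    strict_replacement("op", **bindings)
    strict_failed = False
except TypeError:
    strict_failed = True
print(strict_failed)
```

The version without **_ fails as soon as it receives a binding it does not declare, which is why the unused pattern variables need to be absorbed.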
In order to validate whether matmul broadcast is sufficient, we write a condition checking function as below.
Note that the relevant inputs passed to the check function are all instances of onnx_ir.Value. These represent the values in the input graph IR that matched against the corresponding pattern variables in the target pattern. Please see documentation of the IR API for more details on how to use it, for example to identify the type or shape or rank of these values.
import logging

logger = logging.getLogger(__name__)

def check_if_not_need_reshape(
    context, input_a: ir.Value, input_b: ir.Value, shape_c: ir.Value, **_
) -> bool:
    """Condition to check if we need to replace the pattern.

    If matmul broadcasting is enough, then we don't need the reshapes.
    To validate this, we need to check the following:
    1. Input shapes check: input_a and input_b should be broadcastable
    2. Output shape check: shape_c should be the same as the output shape from the matmul(input_a, input_b)
    If the above are true, then we don't need the reshapes.

    Returns:
        True if we need to replace the pattern, False otherwise.
    """
    input_a_shape = input_a.shape
    input_b_shape = input_b.shape
    shape_c_tensor = shape_c.const_value
    if shape_c_tensor is None:
        logger.info("The value 'shape_c' is not statically known.")
        return False
    if len(shape_c_tensor.shape) != 1:
        logger.info(
            "Unexpected final shape. The shape of 'shape' value is %s",
            shape_c_tensor.shape,
        )
        return False
    # NOTE: When there is a subset match with a pattern. The MatchResult won't have the shape
    # information. So, we need to check if the shape is None and return False.
    if input_a_shape is None or input_b_shape is None:
        logger.info("Shape information is not available for the inputs and outputs.")
        return False
    input_a_shape = input_a_shape.numpy()
    input_b_shape = input_b_shape.numpy()
    shape_c = shape_c_tensor.numpy().tolist()
    a_rank = len(input_a_shape)
    b_rank = len(input_b_shape)
    # TODO(justinchuby): Check shape size
    # 1. Check if input shapes are broadcastable
    # 1.a. If the first input is 1-D, check whether
    # the dim matches the last second dim of the second input.
    mimic_matmul_broadcast_behavior = False
    if a_rank < 2:
        if b_rank < 2:
            logger.info("Optimization of dot product is not supported yet.")
            return False
        if input_a_shape[-1] != input_b_shape[-2]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_a_shape = [1, *input_a_shape]
            a_rank = len(input_a_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.b. If the second input is 1-D, check whether
    # the dim matches the last dim of the first input.
    if b_rank < 2:
        if input_b_shape[-1] != input_a_shape[-1]:
            logger.info("Original shape is not MatMul compatible.")
            return False
        else:
            input_b_shape = [*input_b_shape, 1]
            b_rank = len(input_b_shape)
            mimic_matmul_broadcast_behavior = True
    # 1.c. If both inputs are at least 2-D, check whether
    # the last dimension of the first input matches the second
    # last dimension of the second input, and shape[:-2] are
    # broadcastable.
    input_a_shape_except_second_last_dim = [*input_a_shape[:-2], *[input_a_shape[-1]]]
    input_b_shape_except_last_dim = input_b_shape[:-1]
    broadcast_matmul_output_shape = [input_a_shape[-2], input_b_shape[-1]]
    for idx, (dim_from_a, dim_from_b) in enumerate(
        zip(
            reversed(input_a_shape_except_second_last_dim),
            reversed(input_b_shape_except_last_dim),
        )
    ):
        if dim_from_a not in {1, dim_from_b}:
            logger.info("Original shape is not broadcastable.")
            return False
        elif idx > 0:
            broadcast_matmul_output_shape = [
                max(dim_from_a, dim_from_b),
                *broadcast_matmul_output_shape,
            ]
    # 2. Check if output shape is the same as the output shape from the matmul(input_a, input_b)
    # Prepend the broadcast_matmul_output_shape with the longer shape of input
    if a_rank > b_rank:
        longer_shape = input_a_shape
        shorter_shape = input_b_shape
    else:
        longer_shape = input_b_shape
        shorter_shape = input_a_shape
    broadcast_matmul_output_shape = [
        *longer_shape[: -len(shorter_shape)],
        *broadcast_matmul_output_shape,
    ]
    if mimic_matmul_broadcast_behavior and b_rank == 2 and input_b_shape[-1] == 1:
        # If input_b is expanded to 2-D, then we need to remove the last dimension
        broadcast_matmul_output_shape = broadcast_matmul_output_shape[:-1]
    if mimic_matmul_broadcast_behavior and a_rank == 2 and input_a_shape[0] == 1:
        # If input_a is expanded to 2-D, then we need to remove the first dimension
        # of input_a, which would be the -2nd dimension of the output shape.
        broadcast_matmul_output_shape.pop(-2)
    if shape_c != broadcast_matmul_output_shape:
        logger.info(
            "Final output shape is not the same. Expected %s vs actual %s",
            shape_c,
            broadcast_matmul_output_shape,
        )
        return False
    return True
With all the necessary components in place, the pattern rewrite rule with the match_condition function is created and then rewriter.rewrite is called to apply the rewrite.
def apply_rewrite(model):
    # Create rewrite rules
    two_reshapes_matmul_reshape_rule = pattern.RewriteRule(
        two_reshapes_matmul_reshape_pattern,  # target pattern
        matmul_pattern,  # replacement pattern
        check_if_not_need_reshape,  # match_condition function
    )
    # Create a Rewrite Rule Set
    rewrite_rule_set = pattern.RewriteRuleSet([two_reshapes_matmul_reshape_rule])
    # Apply rewrite while passing match_condition
    model_with_rewrite = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite
The final graph with the applied rewrite looks as follows:
Using MatchContext for Advanced Condition Checking¶
The context parameter passed to condition functions is an instance of onnxscript.rewriter.MatchContext, which provides access to additional information about the pattern match that can be useful for sophisticated condition checking.
MatchContext Properties¶
The MatchContext provides the following read-only properties:
model: The entire ONNX model being matched
graph_or_function: The specific graph or function being matched
root: The root node of the matching subgraph
output_values: The output values of the matching subgraph
nodes: All nodes that are part of the matching subgraph
Example Usage¶
Here’s an example showing how to use the MatchContext to implement more sophisticated condition checking:
def advanced_condition_check(context, x, y, **_):
    """Example condition function using MatchContext."""
    # Access the main node of the pattern match
    main_node = context.root
    # Check that the main_node does not have an attribute called "alpha"
    if "alpha" in main_node.attributes:
        return False
    # Access the broader graph context and check that x occurs as a graph-input
    model = context.model
    if x not in model.graph.inputs:
        return False
    # You can inspect the matched nodes for advanced validation
    for node in context.nodes:
        if node.op_type == "Constant":
            # Check properties of constant nodes in the match
            pass
    # Access output values for shape/type validation
    outputs = context.output_values
    if len(outputs) > 0 and outputs[0].shape is not None:
        # Validate output shapes
        pass
    return True
This context information enables condition functions to make decisions based on the broader graph structure, the specific nodes involved in the match, and relationships between matched patterns and the rest of the model.
OR Patterns¶
Note: This feature is a work in progress.
Consider the following pattern:
def scaled_matmul(op, x, y, factor):
    xy = op.MatMul(x, y)
    choice1 = op.Mul(xy, factor)
    choice2 = op.Div(xy, factor)
    scaled_xy = pattern.OrValue(
        [choice1, choice2], tag_var="op_type", tag_values=["Mul", "Div"]
    )
    return op.Relu(scaled_xy)
This pattern will successfully match against the sequence “MatMul => Mul => Relu” as well as the sequence “MatMul => Div => Relu”. The matcher will bind the variable specified in tag_var (op_type in the above example) to a value from those listed in tag_values to indicate which of the alternatives was used for a successful match. We can use this in the rewrite function to determine how we want to rewrite the matched sub-graph, as illustrated by the following code:
def scaled_matmul_replacement(op, x, y, factor, op_type):
    if op_type == "Mul":
        return op.MatMulMulRelu(x, y, factor, _domain="some.domain")
    elif op_type == "Div":
        return op.MatMulDivRelu(x, y, factor, _domain="some.domain")
    else:
        raise ValueError(f"Unknown operation type: {op_type}")
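The tag-binding idea can be mimicked on a toy representation to make the data flow explicit. The sketch below is not the rewriter's matcher: it works on a plain list of operator names, and match_scaled_matmul and fused_op_name are hypothetical functions written for this illustration.

```python
def match_scaled_matmul(op_chain, tag_values=("Mul", "Div")):
    """Toy matcher for MatMul => (Mul | Div) => Relu over a list of op types.

    Returns a bindings dict with "op_type" set to the alternative that
    matched, or None when the chain does not match.
    """
    if len(op_chain) != 3 or op_chain[0] != "MatMul" or op_chain[2] != "Relu":
        return None
    for tag in tag_values:
        if op_chain[1] == tag:
            return {"op_type": tag}
    return None


def fused_op_name(bindings):
    """Pick the replacement operator from the tag bound by the matcher."""
    return f"MatMul{bindings['op_type']}Relu"


bindings = match_scaled_matmul(["MatMul", "Div", "Relu"])
print(bindings)
print(fused_op_name(bindings))
print(match_scaled_matmul(["MatMul", "Add", "Relu"]))
```

The matcher records which alternative succeeded, and the rewrite step reads that record instead of re-inspecting the graph.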
Utilizing commute parameter for pattern-matching¶
Warning
Please note that the section below describes a convenience feature for handling commutative operators in pattern matching. However, the implementation is a simple, brute-force, technique that generates a collection of rewrite-rules from a given rule, taking commutativity of addition and multiplication into account. This can lead to an exponential increase in the number of rewrite-rules. So, it should be used with caution. Pattern disjunctions (OR Patterns) described earlier can be used judiciously to get a somewhat more efficient implementation in practice (even though the potential for exponential increase still exists within the pattern matching algorithm). Reimplementing commutativity handling using pattern disjunctions is future work.
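The growth mentioned in the warning is easy to see on a toy expression tree. The sketch below enumerates operand orderings for nested tuples in plain Python; it illustrates the brute-force idea only and is not how the rewriter represents or expands patterns. commuted_variants and the tuple encoding are assumptions made for this example.

```python
from itertools import product

COMMUTATIVE = {"Add", "Mul"}


def commuted_variants(expr):
    """Return every operand ordering of a nested-tuple expression.

    An expression is either a leaf (str or float) or (op_name, *operands).
    """
    if not isinstance(expr, tuple):
        return [expr]
    op, *operands = expr
    variants = []
    for combo in product(*(commuted_variants(arg) for arg in operands)):
        variants.append((op, *combo))
        if op in COMMUTATIVE and len(combo) == 2:
            # Also emit the swapped operand order for Add and Mul.
            variants.append((op, combo[1], combo[0]))
    return variants


# 0.5 * (x * (Erf(x / sqrt2) + 1.0)), the erf-based GELU expression
gelu_expr = ("Mul", 0.5, ("Mul", "x", ("Add", ("Erf", ("Div", "x", "sqrt2")), 1.0)))
variants = commuted_variants(gelu_expr)
print(len(variants))
```

The GELU expression contains three commutative nodes (two Mul and one Add), so the enumeration produces 2**3 = 8 variants; every additional commutative node doubles the count.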
Extending the previous simple example, assuming a scenario where we have a graph with the following structure.
In this graph, there are two node patterns that each constitute a GELU op. However, there is a subtle difference between the two. Focusing on the parent Mul nodes in each pattern, the order of the input values being multiplied is switched.
If we utilize the same target_pattern created for the earlier simple example (shown below), only one of the two GELU patterns will be matched.
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
Only one of the patterns has been successfully matched and replaced by a GELU node. In order to rewrite both the existing patterns in the graph, there are two methods.
1. Creating a rule-set with different patterns.¶
This method requires creating two separate rules and packing them into either a sequence of PatternRewriteRules or a RewriteRuleSet. Creating a RewriteRuleSet is the preferable option but either can be used. In order to create a RewriteRuleSet with multiple rules rule1 and rule2 for example:
from onnxscript.rewriter import pattern
rewrite_rule_set = pattern.RewriteRuleSet(rules=[rule1, rule2])
In order to apply this method to the example above, first create the two separate target patterns as follows:
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

def erf_gelu_pattern_2(op, x):
    return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5
Note
When you pass multiple rules in pattern_rewrite_rules, the order in which they appear is important.
This is because some rules may depend on patterns created or modified by earlier rules. For example, if rule2 can only match after rule1 has made a specific change in the model, then rule1 must come before rule2 in the list.
If you’re not seeing expected results, try adjusting the order or applying the rule set in a loop until no more changes occur.
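The effect of rule order, and of looping until nothing changes, can be shown with toy rules on strings. This is a minimal standard-library sketch; single_pass, apply_until_stable, and the two string rules are hypothetical and only stand in for applying a rule set to a model.

```python
def single_pass(value, rules):
    """Apply each rule once, in the given order."""
    for rule in rules:
        value = rule(value)
    return value


def apply_until_stable(value, rules, max_passes=100):
    """Repeat whole passes until a pass leaves the value unchanged."""
    for _ in range(max_passes):
        new_value = single_pass(value, rules)
        if new_value == value:
            return value
        value = new_value
    raise RuntimeError("rules did not converge")


def rule1(s):
    # Removes an addition of zero.
    return s.replace("a+0", "a")


def rule2(s):
    # Removes redundant parentheses; only matches what rule1 leaves behind.
    return s.replace("(a)", "a")


print(single_pass("((a+0))", [rule1, rule2]))
print(single_pass("((a+0))", [rule2, rule1]))
print(apply_until_stable("((a+0))", [rule2, rule1]))
```

In a single pass the two orders give different results, because rule2 has nothing to match until rule1 has run. Repeating passes until the value stops changing reaches the same final result for either order in this toy case.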
Then, create two separate PatternRewriteRules, one for each target pattern. Pack these rules into a RewriteRuleSet object and apply rewrites by passing the created RewriteRuleSet for the pattern_rewrite_rules parameter.
def apply_rewrite_with_ruleset(model):
    # Create multiple rules
    rule1 = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    rule2 = pattern.RewriteRule(
        erf_gelu_pattern_2,  # Target Pattern
        gelu,  # Replacement
    )
    # Create a Rewrite Rule Set with multiple rules.
    rewrite_rule_set = pattern.RewriteRuleSet([rule1, rule2])
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
        # pattern_rewrite_rules=[rule1, rule2], # Alternative method of passing multiple rules
    )
    return model_with_rewrite_applied
2. Using the commute parameter while creating a rule.¶
Creating multiple target patterns for similar patterns can be tedious. In order to avoid this, the commute parameter can be utilized while creating the RewriteRuleSet. Simply set commute=True in order to avoid creating multiple target patterns for cases where patterns differ only due to commutativity. The multiple rules that arise from the commutativity property are automatically packed into a RewriteRuleSet object. Then apply rewrites by passing the created RewriteRuleSet for the pattern_rewrite_rules parameter.
def apply_rewrite_with_commute(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    # Create a Rewrite Rule Set with commute=True
    rewrite_rule_set = pattern.RewriteRuleSet([rule], commute=True)
    # Apply rewrites
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=rewrite_rule_set,
    )
    return model_with_rewrite_applied
For both of the aforementioned methods, the final graph with both rewrites applied should look as follows:
Node and Value Level Checkers¶
The pattern matching infrastructure supports custom validation logic at both the node and value levels through checker functions. These checkers allow for more sophisticated pattern matching by enabling additional constraints beyond basic operator and structure matching.
Value-Level Checkers¶
Value-level checkers validate properties of specific values in the pattern. They are particularly useful for checking constants, shapes, or other value-specific properties.
Basic Usage¶
A value checker is a function that takes a MatchContext and an ir.Value, and returns either a boolean or a MatchResult:
def is_positive_constant(context, value: ir.Value):
    """Check if a value is a positive constant."""
    if value.const_value is not None:
        # Get the numpy array from const_value
        numpy_array = value.const_value.numpy()
        # Check if it represents a single value and is positive
        if numpy_array.size != 1:
            return False
        return float(numpy_array.item()) > 0
    return False
You can use this checker directly in your pattern by passing the callable as an input:
def add_pattern(op, x, y):
    # Use callable as input to create ValuePattern with checker
    return op.Add(is_positive_constant, y)
This pattern will only match Add operations where the first input is a positive constant value.
Example Usage¶
from onnxscript.rewriter import pattern
from onnxscript import ir, optimizer
import onnx
# Create a model with different Add operations
model_proto = onnx.parser.parse_model("""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N] x, float[N] y) => (float[N] z1, float[N] z2, float[N] z3)
{
pos_const = Constant <value_float = 2.5> ()
neg_const = Constant <value_float = -1.5> ()
z1 = Add(x, y) # non-constant first parameter
z2 = Add(pos_const, y) # positive constant first parameter
z3 = Add(neg_const, y) # negative constant first parameter
}
""")
model = ir.serde.deserialize_model(model_proto)
# Apply constant propagation to set const_value fields
optimizer.basic_constant_propagation(model.graph.all_nodes())
# Create the pattern with value checker
rule_pattern = pattern.Pattern(add_pattern)
# Test matching against different Add nodes
add_nodes = [node for node in model.graph if node.op_type == "Add"]
# Non-constant first parameter - will not match
match_result = rule_pattern.match(model, model.graph, add_nodes[0])
print(f"Non-constant: {bool(match_result)}") # False
# Positive constant first parameter - will match
match_result = rule_pattern.match(model, model.graph, add_nodes[1])
print(f"Positive constant: {bool(match_result)}") # True
# Negative constant first parameter - will not match
match_result = rule_pattern.match(model, model.graph, add_nodes[2])
print(f"Negative constant: {bool(match_result)}") # False
Node-Level Checkers¶
Node-level checkers validate properties of the operation nodes themselves, such as attributes, operation types, or other node-specific properties.
Basic Usage¶
A node checker is a function that takes a MatchContext and an ir.Node, and returns either a boolean or a MatchResult:
def shape_node_checker(context, node):
    """Check if a Shape operation has start attribute equal to 0."""
    return node.attributes.get_int("start", 0) == 0
You can use this checker by passing it to the _check parameter of an operation:
def shape_pattern(op, x):
    return op.Shape(x, _check=shape_node_checker)
This pattern will only match Shape operations where the start attribute is 0 (or not present, as the default is 0).
Example Usage¶
from onnxscript.rewriter import pattern
from onnxscript import ir
import onnx
# Create a model with different Shape operations
model_proto = onnx.parser.parse_model("""
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[N, M] x) => (int64[2] z1, int64[2] z2, int64[1] z3)
{
z1 = Shape(x)
z2 = Shape <start: int = 0>(x)
z3 = Shape <start: int = 1>(x)
}
""")
model = ir.serde.deserialize_model(model_proto)
# Create the pattern with node checker
rule_pattern = pattern.Pattern(shape_pattern)
# Test matching against different Shape nodes
nodes = list(model.graph)
shape_nodes = [node for node in nodes if node.op_type == "Shape"]
# Shape without start attribute (default 0) - will match
match_result = rule_pattern.match(model, model.graph, shape_nodes[0])
print(f"No start attr: {bool(match_result)}") # True
# Shape with start=0 - will match
match_result = rule_pattern.match(model, model.graph, shape_nodes[1])
print(f"Start=0: {bool(match_result)}") # True
# Shape with start=1 - will not match
match_result = rule_pattern.match(model, model.graph, shape_nodes[2])
print(f"Start=1: {bool(match_result)}") # False
Combining Checkers¶
You can combine both node-level and value-level checkers in the same pattern for more sophisticated matching:
def complex_pattern(op, x, y):
    # Value-level checker for first input
    validated_x = is_positive_constant
    # Node-level checker for the operation
    return op.Add(validated_x, y, _check=lambda ctx, node: len(node.attributes) == 0)
This pattern will only match Add operations where:
The first input is a positive constant (value-level check)
The node has no custom attributes (node-level check)
Execution Timing and Limitations¶
When Checkers Are Called¶
Node-level and value-level checkers are called only at the end of the complete structural match. This means:
Structural matching happens first: The pattern matching engine first validates that the graph structure matches the pattern (correct operators, connections, etc.)
Checkers run after structural validation: Only after the structural match succeeds do the node and value checkers execute
Order of execution: Value-level checkers run first, followed by node-level checkers, and finally the pattern’s condition function
Limitations with Pattern Disjunctions¶
One important limitation of this design is that these checks don’t compose well with pattern disjunctions (multiple alternative patterns). When searching among multiple value patterns:
Only structural checking is performed initially: If structural matching succeeds for the first alternative, other alternatives are not considered
Checker failures don’t trigger backtracking: If a checker fails, the entire pattern match fails rather than trying the next alternative pattern
This means you should be careful when designing patterns with multiple alternatives that rely on checkers, as the checker logic may prevent exploration of valid alternative matches.
Error Handling¶
Checkers can return either:
True: Check passed, continue matching
False: Check failed, pattern does not match
MatchResult: More detailed result with potential failure reasons
If a checker raises an exception, it will be caught and treated as a match failure, allowing patterns to fail gracefully when encountering unexpected conditions.
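The exception behaviour described above can be pictured as a thin wrapper around the checker call. The sketch below is a plain-Python illustration, not the rewriter's code: run_checker is a hypothetical name, and the result is reduced to a boolean with bool(), on the assumption that a detailed result object is truthy when the check passed.

```python
def run_checker(checker, *args):
    """Hypothetical wrapper: an exception from a checker counts as a failed match."""
    try:
        return bool(checker(*args))
    except Exception:
        return False


def is_positive(context, value):
    # Raises TypeError when value is None, since None cannot be compared to 0.
    return value > 0


print(run_checker(is_positive, None, 3))
print(run_checker(is_positive, None, -1))
print(run_checker(is_positive, None, None))
```

The third call shows the graceful-failure case: the checker raises, and the wrapper reports a non-match instead of propagating the error.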