A Simple Example
This functionality is demonstrated below with a simple example based on the GELU activation function. GELU can be computed with the Gauss error function using the following formula:
$$\mathrm{GELU}(x) = 0.5\, x \left(1 + \operatorname{erf}\left(\frac{x}{\sqrt{2}}\right)\right)$$
We will show how to find a subgraph that matches this computation and replace it with a single call to the GELU function.
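As a quick sanity check of the formula, separate from the rewriter workflow, the sketch below evaluates the erf-based expression with Python's standard library. It then compares the result against the equivalent definition GELU(x) = x·Φ(x), where Φ is the standard normal CDF, taken here from `statistics.NormalDist`. The helper names `gelu_erf` and `gelu_cdf` are illustrative only.

```python
import math
from statistics import NormalDist


def gelu_erf(x: float) -> float:
    # Same expression as the target pattern: 0.5 * (x * (erf(x / sqrt(2)) + 1))
    return 0.5 * (x * (math.erf(x / math.sqrt(2)) + 1.0))


def gelu_cdf(x: float) -> float:
    # Equivalent form: x times the standard normal CDF Phi(x)
    return x * NormalDist(mu=0.0, sigma=1.0).cdf(x)


for v in (-3.0, -1.0, 0.0, 0.5, 2.0):
    print(f"x={v:+.1f}  erf form={gelu_erf(v):+.6f}  cdf form={gelu_cdf(v):+.6f}")
```

The two columns agree up to floating-point rounding, which confirms that the erf-based subgraph computes GELU.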
First, add the required imports, including the rewriter modules.
import math

import onnxscript.rewriter
from onnxscript import ir
from onnxscript.rewriter import pattern
Next, define the target pattern to be replaced, using onnxscript operators.
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))
Then define the replacement pattern, which consists of a single Gelu operator from the com.microsoft domain.
def gelu(op, x: ir.Value):
    return op.Gelu(x, _domain="com.microsoft")
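Both functions are plain Python: they receive an `op` builder and describe a graph through operator calls and arithmetic. The toy tracer below records each node a function would emit, which shows what each function describes. `ToyOp`, `ToyValue`, and `trace` are hypothetical stand-ins written for this illustration. They are not onnxscript classes and do not reflect how the rewriter works internally.

```python
import math


class ToyValue:
    """A symbolic value; arithmetic on it records a node in the tracer."""

    def __init__(self, tracer, name):
        self.tracer = tracer
        self.name = name

    def __mul__(self, other):
        return self.tracer.node("Mul", self, other)

    def __rmul__(self, other):  # 0.5 * value keeps the constant as the first input
        return self.tracer.node("Mul", other, self)

    def __add__(self, other):
        return self.tracer.node("Add", self, other)

    def __truediv__(self, other):
        return self.tracer.node("Div", self, other)


class ToyOp:
    """Hypothetical stand-in for the `op` builder: op.Erf(...) etc. record nodes."""

    def __init__(self):
        self.nodes = []  # entries: (domain, op_type, input_names, output_name)

    def node(self, op_type, *inputs, domain=""):
        out = ToyValue(self, f"t{len(self.nodes)}")
        names = [v.name if isinstance(v, ToyValue) else repr(v) for v in inputs]
        self.nodes.append((domain, op_type, names, out.name))
        return out

    def __getattr__(self, op_type):
        # Any unknown attribute, such as Erf or Gelu, becomes a node factory.
        return lambda *inputs, _domain="": self.node(op_type, *inputs, domain=_domain)


def trace(fn):
    """Call a pattern-style function with the toy builder and return its nodes."""
    op = ToyOp()
    fn(op, ToyValue(op, "x"))
    return op.nodes


# The two functions from the text, repeated so this block runs on its own.
def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))


def gelu(op, x):
    return op.Gelu(x, _domain="com.microsoft")


for domain, op_type, inputs, output in trace(erf_gelu_pattern):
    print(op_type, inputs, "->", output)
print(trace(gelu))
```

The trace of the target pattern is a Div, Erf, Add, Mul, Mul chain. The replacement traces to a single Gelu node in the com.microsoft domain. Together these are the subgraph to be found and the node that replaces it.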
Note
The inputs to the replacement pattern are of type ir.Value. For detailed usage, refer to the ir.Value class.
This example does not require a match_condition, so that option is skipped for now. The rewrite rule is then created with the RewriteRule class:
rule = pattern.RewriteRule(
    erf_gelu_pattern,  # Target pattern
    gelu,  # Replacement pattern
)
More complex rewrite rules are often easier to organize as a class. The rule above can alternatively be expressed as follows:
class ErfGeluFusion(pattern.RewriteRuleClassBase):
    def pattern(self, op, x):
        # Same expression (and operand order) as erf_gelu_pattern above
        return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

    def rewrite(self, op, x):
        return op.Gelu(x, _domain="com.microsoft")
The corresponding rewrite rule is obtained as follows:
erf_gelu_rule_from_class = ErfGeluFusion.rule()
With the rewrite rule created, the next step is to apply it to a model. The rewriter.rewrite(model, pattern_rewrite_rules) call applies the specified rewrite rules to the given model. Its parameters are:
model: The original model to which the pattern rewrite rules are applied, of type ir.Model or onnx.ModelProto. If the model is an ir.Model, the rewriter applies the changes in place, modifying the input model. If it is a ModelProto, the rewriter returns a new ModelProto representing the transformed model.
pattern_rewrite_rules: Either a Sequence[PatternRewriteRule] or a RewriteRuleSet.
Note
For steps on how to create and use Rule-sets, refer to the example in the section Creating a rule-set with different patterns.
The snippet below demonstrates how to use the rewriter.rewrite call with the rewrite rule created above:
def apply_rewrite(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target pattern
        gelu,  # Replacement pattern
    )
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[rule],
    )
    return model_with_rewrite_applied
The graph (on the left) consists of the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph (on the right) shows that the target pattern has been successfully replaced by a GELU node as intended.
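The effect of this rewrite can also be reproduced on a toy graph that needs no ONNX dependency. The sketch below is an illustration under simplifying assumptions and is not onnxscript's matching algorithm. Nodes are plain `Node` records. The constants 0.5, 1.0, and √2 are represented by the hypothetical value names `half`, `one`, and `sqrt2`. Only the exact operand order of `erf_gelu_pattern` is matched. The sketch also declines to fuse when an intermediate value has another consumer, because removing that node would break the graph.

```python
from collections import Counter
from dataclasses import dataclass


@dataclass
class Node:
    op: str        # operator type, e.g. "Mul"
    inputs: list   # names of input values
    output: str    # name of the single output value


def build_erf_gelu_graph():
    # y = 0.5 * (x * (Erf(x / sqrt2) + 1.0)), one node per operation
    return [
        Node("Div", ["x", "sqrt2"], "t1"),
        Node("Erf", ["t1"], "t2"),
        Node("Add", ["t2", "one"], "t3"),
        Node("Mul", ["x", "t3"], "t4"),
        Node("Mul", ["half", "t4"], "y"),
    ]


def match_erf_gelu(nodes, root):
    """Return (x, matched_nodes) if root computes the erf-based GELU, else None."""
    producer = {n.output: n for n in nodes}
    if root.op != "Mul" or root.inputs[0] != "half":
        return None
    mul = producer.get(root.inputs[1])
    if mul is None or mul.op != "Mul":
        return None
    x, add_out = mul.inputs
    add = producer.get(add_out)
    if add is None or add.op != "Add" or add.inputs[1] != "one":
        return None
    erf = producer.get(add.inputs[0])
    if erf is None or erf.op != "Erf":
        return None
    div = producer.get(erf.inputs[0])
    if div is None or div.op != "Div" or div.inputs != [x, "sqrt2"]:
        return None
    matched = [div, erf, add, mul, root]
    # Each intermediate value must feed only the next node in the chain.
    consumers = Counter(name for n in nodes for name in n.inputs)
    if any(consumers[n.output] != 1 for n in matched[:-1]):
        return None
    return x, matched


def apply_fusion(nodes):
    """Repeatedly replace matched subgraphs with a single Gelu node."""
    nodes = list(nodes)
    changed = True
    while changed:
        changed = False
        for root in nodes:
            match = match_erf_gelu(nodes, root)
            if match is None:
                continue
            x, matched = match
            matched_ids = {id(n) for n in matched}
            gelu = Node("Gelu", [x], root.output)
            # Drop the matched nodes, putting Gelu where the root Mul was.
            nodes = [gelu if n is root else n
                     for n in nodes if n is root or id(n) not in matched_ids]
            changed = True
            break
    return nodes


before = build_erf_gelu_graph()
after = apply_fusion(before)
print([n.op for n in before])  # ['Div', 'Erf', 'Add', 'Mul', 'Mul']
print([n.op for n in after])   # ['Gelu']
```

Matching from the final Mul outward, instead of scanning for Div first, keeps the check anchored on the node whose output survives the rewrite. That output name `y` is reused by the new Gelu node, so downstream consumers are unaffected.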