A Simple Example

This simple example demonstrates pattern-based rewriting using the GELU activation function.

The GELU activation function can be computed from the Gauss error function (erf) using the following formula:

\[\text{GELU}(x) = x\Phi(x) = x \cdot \frac{1}{2} [1 + \text{erf}(x / \sqrt{2})]\]
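As a quick sanity check of the formula above, the following minimal sketch evaluates it step by step in plain Python and compares the result against x multiplied by the standard normal CDF. The function names `gelu_erf` and `gelu_cdf` are illustrative and not part of onnxscript, and the operator names in the comments (Div, Erf, Add, Mul) are an assumed reading of how the computation would appear as individual graph nodes.

```python
import math
from statistics import NormalDist


def gelu_erf(x: float) -> float:
    """GELU via the error function, one primitive step per line."""
    scaled = x / math.sqrt(2)   # Div
    erf_out = math.erf(scaled)  # Erf
    shifted = erf_out + 1.0     # Add
    gated = x * shifted         # Mul
    return 0.5 * gated          # Mul


def gelu_cdf(x: float) -> float:
    """Reference value: x * Phi(x), with Phi the standard normal CDF."""
    return x * NormalDist().cdf(x)


# Both formulations should agree up to floating-point rounding.
for value in (-3.0, -1.0, 0.0, 0.5, 2.0):
    assert math.isclose(gelu_erf(value), gelu_cdf(value), abs_tol=1e-12)
```

Seen this way, the formula is a small chain of elementwise operations, which is the kind of subgraph a single fused operator can replace.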

We will show how to find a subgraph matching this computation and replace it with a single call to the GELU function.

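First, import everything the example needs.

import math  # math.sqrt is used in the target pattern
import onnxscript  # onnxscript.rewriter.rewrite is used to apply the rule
from onnxscript import ir
from onnxscript.rewriter import pattern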

Then, using onnxscript operators, create the target pattern that is to be replaced.

def erf_gelu_pattern(op, x):
    return 0.5 * (x * (op.Erf(x / math.sqrt(2)) + 1.0))

After this, create a replacement pattern that consists of the GELU onnxscript operator.

def gelu(op, x: ir.Value):
    return op.Gelu(x, _domain="com.microsoft")

Note

The inputs to the replacement pattern are of type ir.Value. For detailed usage of ir.Value refer to the ir.Value class.

For this example, we do not require a match_condition, so that option is skipped for now. The rewrite rule is then created using the RewriteRule class.

rule = pattern.RewriteRule(
    erf_gelu_pattern,  # Target Pattern
    gelu,  # Replacement Pattern
)

It is more convenient to organize more complex rewrite-rules as a class. The above rule can be alternatively expressed as below.

class ErfGeluFusion(pattern.RewriteRuleClassBase):
    def pattern(self, op, x):
        return (x * (op.Erf(x / math.sqrt(2)) + 1.0)) * 0.5

    def rewrite(self, op, x):
        return op.Gelu(x, _domain="com.microsoft")

The corresponding rewrite-rule can be obtained as below:

erf_gelu_rule_from_class = ErfGeluFusion.rule()

Now that the rewrite rule has been created, the next step is to apply it. The rewriter.rewrite(model, pattern_rewrite_rules) call applies the specified pattern-based rewrite rules to the given model. Its parameters are:

  1. model : The original model on which the pattern rewrite rules are to be applied. This is of type ir.Model or onnx.ModelProto. If the model is an ir.Model, the rewriter applies the changes in-place, modifying the input model. If it is a ModelProto, the rewriter returns a new ModelProto representing the transformed model.

  2. pattern_rewrite_rules : This parameter is either a Sequence[PatternRewriteRule] or a RewriteRuleSet.

Note

For steps on how to create and use Rule-sets, refer to the example in the section Creating a rule-set with different patterns.

The snippet below demonstrates how to use the rewriter.rewrite call with the rewrite rule created above:

def apply_rewrite(model):
    rule = pattern.RewriteRule(
        erf_gelu_pattern,  # Target Pattern
        gelu,  # Replacement
    )
    model_with_rewrite_applied = onnxscript.rewriter.rewrite(
        model,
        pattern_rewrite_rules=[rule],
    )
    return model_with_rewrite_applied
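One way to confirm that a rewrite took effect is to count the operator types in the graph before and after applying it. The sketch below is a hedged illustration: `count_op_types` is a hypothetical helper, not part of onnxscript, and it runs on stand-in node objects so that it needs nothing beyond the standard library. The stand-in op lists are assumptions about what the target and rewritten graphs contain; a real exported model may include additional nodes such as constants.

```python
from collections import Counter
from types import SimpleNamespace


def count_op_types(nodes):
    """Count nodes per operator type.

    `nodes` is any iterable of objects exposing an `op_type` attribute.
    """
    return Counter(node.op_type for node in nodes)


# Stand-in nodes for illustration only (assumed graph contents).
before = [SimpleNamespace(op_type=t) for t in ("Div", "Erf", "Add", "Mul", "Mul")]
after = [SimpleNamespace(op_type="Gelu")]

print(count_op_types(before))
print(count_op_types(after))
```

For an onnx.ModelProto, the nodes to pass in would be `model.graph.node`. For an ir.Model, this sketch assumes that iterating over `model.graph` yields the nodes, each with an `op_type` attribute.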

The graph on the left shows the target pattern before the rewrite rule is applied. Once the rewrite rule is applied, the graph on the right shows that the target pattern has been replaced by a single GELU node, as intended.

[Figures: target_pattern (left), replacement_pattern (right)]