RewriteRuleClassBase

class onnxscript.rewriter.pattern.RewriteRuleClassBase(name: str | None = None, remove_nodes: bool = True, as_function: bool = False)[source]

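Base class for implementing rewrite rules as classes. A subclass defines `pattern`, which describes the subgraph to match; `check`, which applies additional conditions to a candidate match; and `rewrite`, which builds the replacement subgraph.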

Example:

class TransposeIdentity(RewriteRuleClassBase):
    """Replace a ``Transpose`` whose ``perm`` is the identity permutation with ``Identity``."""

    # These are instance methods (only ``rule`` is a classmethod), so the
    # first parameter is ``self``.
    def pattern(self, op, x, perm):
        # Match any Transpose node; ``perm`` captures its ``perm`` attribute.
        return op.Transpose(x, perm=perm)

    def check(self, context, x: ir.Value, perm: ir.Attr) -> bool:
        # Reject reference attributes; presumably their concrete value is not
        # available when the match is checked — confirm against ir.Attr docs.
        if perm.is_ref():
            return False
        if perm.type != ir.AttributeType.INTS:
            return False
        axes = list(perm.as_ints())
        # The transpose is a no-op only when perm == [0, 1, ..., n-1].
        return axes == list(range(len(axes)))

    def rewrite(self, op, x: ir.Value, perm: ir.Attr | None = None):
        # An identity transpose leaves its input unchanged.
        return op.Identity(x)


# Then use
# TransposeIdentity.rule()
# to create a RewriteRule object.
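classmethod rule(*args, **kwargs)[source]

Create a RewriteRule object from this class, for example TransposeIdentity.rule().

abstract rewrite(op, *args, **kwargs)[source]

Build the replacement for a matched pattern. Derived classes must override this method.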
setup()[source]

Optional setup function that can be overridden by derived classes.

Used to perform per-model or per-function initialization.

cleanup()[source]

Optional cleanup function that can be overridden by derived classes.

Used to perform per-model or per-function cleanup.