RewriteRuleClassBase
- `class onnxscript.rewriter.pattern.RewriteRuleClassBase(name: str | None = None, remove_nodes: bool = True, as_function: bool = False)`
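Base class for implementing rewrite rules as a class. A subclass defines `pattern`, `check`, and `rewrite` methods, and calling `rule()` on the subclass creates a `RewriteRule` object, as shown in the example below.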
**Example:**
```python
class TransposeIdentity(RewriteRuleClassBase):
    def pattern(cls, op, x, perm):
        return op.Transpose(x, perm=perm)

    def check(cls, context, x: ir.Value, perm: ir.Attr) -> bool:
        if perm.is_ref():
            return False
        if perm.type == ir.AttributeType.INTS:
            if list(perm.as_ints()) == list(range(len(perm.as_ints()))):
                return True
        return False

    def rewrite(cls, op, x: ir.Value, perm: ir.Attr | None = None):
        return op.Identity(x)

# Then use
# TransposeIdentity.rule()
# to create a RewriteRule object.
```