# Source code for onnxscript.rewriter._matcher

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Implementation of the pattern matching algorithm."""

from __future__ import annotations

import abc
import itertools
import math
from typing import (
    Iterable,
    Sequence,
)

import onnxscript.rewriter._basics as _basics
import onnxscript.rewriter._pattern_ir as _pattern_ir
from onnxscript import ir


def _valid_to_replace(
    matched_nodes: Sequence[ir.Node], output_values: Sequence[ir.Value]
) -> bool:
    """Tell whether the matched nodes can be replaced without affecting other users.

    Every value produced by a node in ``matched_nodes`` that is not listed in
    ``output_values`` is treated as an intermediate value. Replacement is
    considered valid only if no intermediate value is a graph output and every
    consumer of every intermediate value is itself one of the matched nodes.

    Args:
        matched_nodes: The nodes bound by the pattern match.
        output_values: The values that the replacement will take over; these
            are exempt from the check.

    Returns:
        True if all intermediate values are used only inside ``matched_nodes``,
        False otherwise.
    """
    # Note: ensuring that the replacement subgraph does not use any of the
    # deleted (intermediate) values is not checked here (not necessary for
    # now; guaranteed).
    intermediates = (
        produced
        for matched in matched_nodes
        for produced in matched.outputs
        if produced not in output_values
    )
    for produced in intermediates:
        if produced.is_graph_output():
            # The value is an output of the graph/function, so it is visible
            # outside the matched nodes.
            return False
        if any(user not in matched_nodes for user, _ in produced.uses()):
            return False
    return True


class PatternMatcher(abc.ABC):
    """Abstract interface for matching a graph pattern against IR nodes.

    Attributes:
        pattern: The graph pattern this matcher looks for.
    """

    def __init__(self, pattern: _pattern_ir.GraphPattern) -> None:
        self.pattern = pattern

    @abc.abstractmethod
    def match(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        node: ir.Node,
        *,
        verbose: int = 0,
        remove_nodes: bool = True,
        tracer: _basics.MatchingTracer | None = None,
    ) -> _basics.MatchResult:
        """Match the pattern against the subgraph ending at the given node."""

    def __str__(self) -> str:
        # A matcher is displayed as the pattern it matches.
        return str(self.pattern)
class SimplePatternMatcher(PatternMatcher):
    """Pattern matcher that walks from the pattern's output node(s) towards the inputs.

    A candidate graph node is compared with a pattern node; on success the
    node's inputs are matched recursively against the pattern node's input
    patterns. The state of the match in progress is kept on the instance
    (``_match``, ``_verbose``, ``_current_node``) and is reset by
    ``_init_match`` before every attempt.
    """

    def __init__(self, pattern: _pattern_ir.GraphPattern) -> None:
        super().__init__(pattern)
        self._current_node: ir.Node | None = None

    def fail(self, reason: str, node: ir.Node | None = None) -> bool:
        """Record a match failure on the current match and return False.

        Args:
            reason: Human-readable explanation of the failure.
            node: The node to report as the failure source; when falsy, the
                node most recently visited by ``_match_node`` is used.

        Returns:
            Always False, so callers can write ``return self.fail(...)``.
        """
        if self._verbose:
            num_matched_nodes = self._match.num_matched_nodes()
            if num_matched_nodes > 0:
                # Print only if at least one node successfully matched.
                print(f"Match failed after {num_matched_nodes} nodes: {reason}")
        self._match.fail(reason, node or self._current_node)
        return False

    def _match_constant(self, pattern_constant: _pattern_ir.Constant, value: ir.Value) -> bool:
        """Match a Constant pattern against a value.

        If the constant value is produced by a Constant node, we do not include
        the constant node as part of the matched graph. Thus, it will not be deleted,
        if subgraph replacement happens. But subsequent DCE will remove the constant
        node if it is not used elsewhere.
        """
        constant_value = value.const_value
        if constant_value is None:
            return self.fail(
                f"Value {value.name} is not a constant, expecting {pattern_constant.value}.",
            )

        try:
            constant_value_numpy = constant_value.numpy()
        except FileNotFoundError:
            return self.fail(f"Constant value of {value.name} not available.")

        pattern_constant_value = pattern_constant._value

        if isinstance(pattern_constant_value, list):
            # A list pattern matches only a 1-D tensor of the same length whose
            # elements are all close to the expected ones.
            expected_shape = (len(pattern_constant_value),)
            if constant_value_numpy.shape != expected_shape:
                return self.fail(f"Value has mismatching shape, expecting {expected_shape}.")
            if not all(
                math.isclose(
                    constant_value_numpy.item(i),
                    expected,
                    rel_tol=pattern_constant._rel_tol,
                    abs_tol=pattern_constant._abs_tol,
                )
                for i, expected in enumerate(pattern_constant_value)
            ):
                return self.fail(
                    f"Value mismatch: expected {pattern_constant_value}, "
                    f"got {constant_value_numpy}."
                )
            return True

        # TODO (rama): allow users to specify shape requirement, if desired.
        if constant_value_numpy.size != 1:
            return self.fail(
                f"Value {value.name} is not a scalar, expecting {pattern_constant_value}.",
            )

        if not math.isclose(
            constant_value_numpy.item(),
            pattern_constant_value,
            rel_tol=pattern_constant._rel_tol,
            abs_tol=pattern_constant._abs_tol,
        ):
            return self.fail(
                f"Constant value mismatch: expected {pattern_constant_value}, "
                f"got {constant_value_numpy.item()}.",
            )

        return True

    def _match_node(self, pattern_node: _pattern_ir.NodePattern, node: ir.Node) -> bool:
        """Matches a pattern subgraph against subgraph rooted at node."""
        self._current_node = node
        # Graph-matching: we do not allow the same pattern node to be matched against
        # different graph nodes.
        matched_node = self._match.lookup_node(pattern_node)
        if matched_node is not None:
            if matched_node is not node:
                return self.fail("Same pattern node is matched against different graph nodes.")
            return True
        match = self._match
        if not pattern_node.matches(node, match):
            return self.fail(match.reason)

        if self._verbose:
            print(f"Matched: {node.op_type}")

        match.bind_node(pattern_node, node)

        # TODO: Revisit this to handle optional trailing inputs better.
        if len(node.inputs) > len(pattern_node.inputs):
            if not pattern_node.allow_other_inputs:
                return self.fail(
                    f"Number of inputs ({len(node.inputs)}) is greater than "
                    f"expected ({len(pattern_node.inputs)})"
                )
            checked_inputs = zip(node.inputs, pattern_node.inputs)
        else:
            # In ONNX, trailing Nones can be omitted in the inputs of a node. So, we extend
            # actual node inputs with None values to match the pattern node inputs length
            # when zipping.
            checked_inputs = itertools.zip_longest(
                node.inputs, pattern_node.inputs, fillvalue=None
            )

        for arg_value, arg_pattern in checked_inputs:
            # arg_pattern could be a Var, if it's the original arg.
            if arg_pattern is None:
                if arg_value is None:
                    continue
                return self.fail("(Optional) input is expected to be None but is not.")
            if not self._match_value(arg_pattern, arg_value):
                return False

        for i, output_value_pattern in enumerate(pattern_node.outputs):
            # When trying to bind more outputs (from the pattern) than there are
            # actual outputs of the candidate node, reject the node before even
            # trying to index into the list of node outputs. The rejection is
            # recorded through fail(), like every other rejection in this method;
            # ``node`` is passed explicitly because the recursive input matching
            # above may have moved ``self._current_node``.
            if i >= len(node.outputs):
                return self.fail(
                    f"Number of outputs ({len(node.outputs)}) is less than "
                    f"expected ({len(pattern_node.outputs)})",
                    node,
                )
            if not self._match.bind_value(output_value_pattern, node.outputs[i]):
                return False

        return True

    def _match_value(
        self, pattern_value: _pattern_ir.ValuePattern, value: ir.Value | None
    ) -> bool:
        """Match an IR value against a ValuePattern instance."""
        if (
            value is not None
            and value.graph is not self._graph
            and not isinstance(
                pattern_value, (_pattern_ir.Var, _pattern_ir.Constant, _pattern_ir.AnyValue)
            )
        ):
            # If the pattern value is a Var, Constant, or AnyValue, we allow it to match
            # values from other graphs. Otherwise, we fail the match.
            return self.fail(
                f"Value {value.name} is not in the graph {self._graph.name}. "
                f"Pattern matches crossing graph boundaries are not supported."
            )

        if isinstance(pattern_value, _pattern_ir.AnyValue):
            return True

        if not self._match.bind_value(pattern_value, value):
            return False

        if isinstance(pattern_value, _pattern_ir.NodeOutputPattern):
            if value is None:
                return self.fail("Mismatch: Computed node pattern does not match None.")
            return self._match_node_output(pattern_value, value)

        if isinstance(pattern_value, _pattern_ir.Constant):
            if value is None:
                return self.fail("Mismatch: Constant pattern does not match None.")
            return self._match_constant(pattern_value, value)

        if isinstance(pattern_value, _pattern_ir.BacktrackingOr):
            # Try each alternative in its own tentative match; keep the first
            # that succeeds and discard the state of the ones that fail.
            for i, pattern_choice in enumerate(pattern_value._values):
                self._match.enter_new_match()
                if self._match_value(pattern_choice, value):
                    if pattern_value.tag_var is not None:
                        self._match.bind(pattern_value.tag_var, pattern_value._tag_values[i])
                    self._match.merge_current_match()
                    return True
                self._match.abandon_current_match()
            return self.fail("None of the alternatives matched.")

        if isinstance(pattern_value, _pattern_ir.OpIdDispatchOr):
            if value is None:
                return self.fail("Mismatch: OrValue pattern does not match None.")
            alternative = pattern_value.get_pattern(value)
            if alternative is None:
                return self.fail("Mismatch: OrValue pattern does not match value.")
            i, pattern_choice = alternative
            result = self._match_value(pattern_choice, value)
            if result:
                if pattern_value.tag_var is not None:
                    self._match.bind(pattern_value.tag_var, i)
            return result

        # Default case: a plain pattern variable (ValuePattern)
        if value is None and not pattern_value.can_match_none:
            return self.fail(
                f"Mismatch: pattern variable {pattern_value} does not match None."
            )
        return True

    def _match_node_output(
        self, pattern_value: _pattern_ir.NodeOutputPattern, value: ir.Value
    ) -> bool:
        """Match an IR value against a NodeOutputPattern instance."""
        node = value.producer()
        if node is None:
            return self.fail(
                "Mismatch: Computed node pattern does not match uncomputed IR value."
            )
        if value.index() != pattern_value.output_index:
            return self.fail(
                f"Node output index mismatch: expected {pattern_value.output_index}, "
                f"got {value.index()}."
            )
        return self._match_node(pattern_value.producer(), node)

    def _init_match(self, verbose: int) -> None:
        """Initialize the match state. Invoked before starting a new match."""
        self._verbose = verbose
        self._match: _basics.MatchResult = _basics.MatchResult()
        self._current_node = None

    def _get_output_values(self) -> list[ir.Value] | None:
        """Get values bound to the output variables of the pattern.

        Returns:
            The bound values, in the order of ``self.pattern.outputs``, or None
            (after marking the match as failed) if any output is unbound.
        """
        output_values: list[ir.Value] = []
        unbound_values: list[str] = []
        for j, value_pattern in enumerate(self.pattern.outputs):
            if value_pattern.name is not None:
                # Named outputs are looked up by name.
                if value_pattern.name in self._match.bindings:
                    output_values.append(self._match.bindings[value_pattern.name])
                else:
                    unbound_values.append(value_pattern.name)
            else:
                # Unnamed outputs are looked up by the pattern value itself.
                if value_pattern in self._match.value_bindings:
                    output_values.append(self._match.value_bindings[value_pattern])
                else:
                    unbound_values.append(f"output_{j}")
        if unbound_values:
            self._match.fail(f"Error: Output values not found: {unbound_values}")
            return None
        return output_values

    def _finish_match(self, check_removable: bool) -> _basics.MatchResult:
        """Collect the pattern outputs and, if requested, check removability.

        Shared tail of the single-output and multi-output matching paths.
        """
        match = self._match
        output_values = self._get_output_values()
        if output_values is None:
            # TODO(rama): Is this a valid (useful) case?
            return match
        if check_removable and not _valid_to_replace(match.nodes, output_values):
            # TODO(rama): Match status should be updated to reflect failure reason.
            return match.fail("Matched nodes have other uses preventing replacement.")
        match.outputs.extend(output_values)
        return match

    def _match_single_output_node(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        node: ir.Node,
        check_removable: bool,
    ) -> _basics.MatchResult:
        """Match a pattern that has a single output node against ``node``."""
        del model
        del graph_or_function

        pattern = self.pattern
        match = self._match

        if not pattern.has_single_output_node:
            return match.fail(
                "Internal Error: SimplePatternMatcher should not be used for patterns "
                "with multiple output nodes."
            )

        if not self._match_node(pattern.output_node, node):
            return match

        return self._finish_match(check_removable)

    def _multi_match(
        self, candidate: Iterable[ir.Node], check_removable: bool
    ) -> _basics.MatchResult:
        """Find a match for a pattern with multiple output nodes.

        For a pattern with K output nodes, the input candidate should specify K nodes
        in the graph that will be matched against the pattern output nodes.

        Args:
            candidate: An iterable of nodes that will be matched against the pattern
                output nodes.
            check_removable: If True, check that the matched nodes can be removed
                (that is, that they are not used elsewhere in the graph).
        """
        match = self._match
        for pattern_node, node in zip(self.pattern.output_nodes, candidate):
            if not self._match_node(pattern_node, node):
                return match
        return self._finish_match(check_removable)

    def match(
        self,
        model: ir.Model,
        graph_or_function: ir.Graph | ir.Function,
        node: ir.Node,
        *,
        verbose: int = 0,
        remove_nodes: bool = True,
        tracer: _basics.MatchingTracer | None = None,
    ) -> _basics.MatchResult:
        """Match the pattern against the subgraph ending at the given node.

        For patterns with multiple output nodes, the given node is matched
        against the first output node in the pattern. For the remaining
        output nodes in the pattern, we use a brute-force algorithm that
        enumerates all possible combinations of nodes from the graph (with
        a filter based on op-type).

        TODO: Consider omitting parameters model and graph_or_function. With
        the new IR, the graph can be obtained from the node, and the model is
        not used. But this is a shared abstract method of the Matcher interface,
        so other matcher implementation also needs to be updated. More importantly,
        matching in the presence of subgraphs (control-flow) can introduce some
        complications which require careful consideration.
        """
        self._tracer = tracer
        if isinstance(graph_or_function, ir.Graph):
            self._graph: ir.Graph = graph_or_function
        else:
            self._graph = graph_or_function.graph

        if self.pattern.has_single_output_node:
            self._init_match(verbose)
            return self._match_single_output_node(
                model, graph_or_function, node, check_removable=remove_nodes
            )

        # Note: This is a potentially expensive algorithm for matching patterns with
        # multiple output nodes. For patterns with N output nodes, we try all possible
        # combinations of N nodes from the graph, and check if they match the pattern.
        # The first node is fixed to the node argument in this method call. We do
        # some simple filtering by restricting the candidates for each remaining
        # output nodes to graph nodes with the same op_type as the corresponding pattern
        # node. For now, this is intended to be a simple, but robust, implementation
        # that can be used for debugging and testing. The GenericPatternMatcher is a
        # more sophisticated implementation, but incomplete.
        pattern_output_nodes = self.pattern.output_nodes

        # Materialize the nodes once. A shared one-shot iterator would be drained
        # by the first wildcard pattern node that uses it, leaving later wildcard
        # pattern nodes with no candidates at all.
        all_nodes: list[ir.Node] = list(graph_or_function)
        op_to_nodes: dict[tuple[str, str, str], list[ir.Node]] = {}
        for n in all_nodes:
            op_to_nodes.setdefault(n.op_identifier(), []).append(n)

        def get_nodes(pattern_node):
            op_id = pattern_node.op_identifier()
            if op_id is None:
                # No op identifier to filter on: every node is a candidate.
                return all_nodes
            return op_to_nodes.get(op_id, [])

        candidates = [[node]] + [get_nodes(pn) for pn in pattern_output_nodes[1:]]
        match = None
        for combination in itertools.product(*candidates):
            self._init_match(verbose)
            match = self._multi_match(combination, check_removable=remove_nodes)
            if match:
                return match
        if match is None:
            return _basics.MatchResult().fail("No match found.")
        # Otherwise report the (failed) result of the last combination tried.
        return match