提交 2d46d60e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove unused rewrites and functionality

上级 40ccab1a
...@@ -134,7 +134,7 @@ computation graph. ...@@ -134,7 +134,7 @@ computation graph.
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`, In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
another while respecting certain validation constraints. As an another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you exercise, try to rewrite :class:`Simplify` using :class:`WalkingGraphRewriter`. (Hint: you
want to use the method it publishes instead of the call to toposort) want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
......
...@@ -26,7 +26,3 @@ Guide ...@@ -26,7 +26,3 @@ Guide
.. class:: ReplaceValidate(History, Validator) .. class:: ReplaceValidate(History, Validator)
.. method:: replace_validate(fgraph, var, new_var, reason=None) .. method:: replace_validate(fgraph, var, new_var, reason=None)
.. class:: NodeFinder(Bookkeeper)
.. class:: PrintListener(object)
...@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator): ...@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator):
raise InconsistencyError("Trying to reintroduce a removed node") raise InconsistencyError("Trying to reintroduce a removed node")
class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}
def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present")
if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph)
def clone(self):
return type(self)()
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.fgraph is not fgraph:
raise Exception(
"This NodeFinder instance was not attached to the provided fgraph."
)
self.fgraph = None
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node, reason):
try:
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception as e:
print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201
try:
print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201
except Exception:
print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201
raise e
def on_prune(self, fgraph, node, reason):
try:
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self.d[node.op]
def query(self, fgraph, op):
try:
all = self.d.get(op, [])
except TypeError:
raise TypeError(
f"{op} in unhashable and cannot be queried by the optimizer"
)
all = list(all)
return all
class PrintListener(Feature):
def __init__(self, active=True):
self.active = active
def on_attach(self, fgraph):
if self.active:
print("-- attaching to: ", fgraph) # noqa: T201
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.active:
print("-- detaching from: ", fgraph) # noqa: T201
def on_import(self, fgraph, node, reason):
if self.active:
print(f"-- importing: {node}, reason: {reason}") # noqa: T201
def on_prune(self, fgraph, node, reason):
if self.active:
print(f"-- pruning: {node}, reason: {reason}") # noqa: T201
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201
class PreserveVariableAttributes(Feature): class PreserveVariableAttributes(Feature):
""" """
This preserve some variables attributes and tag during optimization. This preserve some variables attributes and tag during optimization.
......
...@@ -11,8 +11,7 @@ import traceback ...@@ -11,8 +11,7 @@ import traceback
import warnings import warnings
from collections import Counter, UserList, defaultdict, deque from collections import Counter, UserList, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType from functools import _compose_mro, partial # type: ignore
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal
...@@ -28,7 +27,7 @@ from pytensor.graph.basic import ( ...@@ -28,7 +27,7 @@ from pytensor.graph.basic import (
io_toposort, io_toposort,
vars_between, vars_between,
) )
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.utils import AssocList, InconsistencyError from pytensor.graph.utils import AssocList, InconsistencyError
...@@ -60,14 +59,6 @@ FailureCallbackType = Callable[ ...@@ -60,14 +59,6 @@ FailureCallbackType = Callable[
] ]
class MetaNodeRewriterSkip(AssertionError):
"""This is an `AssertionError`, but instead of having the
`MetaNodeRewriter` print the error, it just skip that
compilation.
"""
class Rewriter(abc.ABC): class Rewriter(abc.ABC):
"""Abstract base class for graph/term rewriters.""" """Abstract base class for graph/term rewriters."""
...@@ -942,129 +933,6 @@ def pre_constant_merge(fgraph, variables): ...@@ -942,129 +933,6 @@ def pre_constant_merge(fgraph, variables):
return [recursive_merge(v) for v in variables] return [recursive_merge(v) for v in variables]
class MetaNodeRewriter(NodeRewriter):
r"""
Base class for meta-rewriters that try a set of `NodeRewriter`\s
to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during
compilation, we will skip that function compilation and not print
the error.
"""
def __init__(self):
self.verbose = config.metaopt__verbose
self.track_dict = defaultdict(list)
self.tag_dict = defaultdict(list)
self._tracks = []
self.rewriters = []
def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):
self.rewriters.append(rewriter)
tracks = rewriter.tracks()
if tracks:
self._tracks.extend(tracks)
for c in tracks:
self.track_dict[c].append(rewriter)
for tag in tag_list:
self.tag_dict[tag].append(rewriter)
def tracks(self):
return self._tracks
def transform(self, fgraph, node, *args, **kwargs):
# safety check: depending on registration, tracks may have been ignored
if self._tracks is not None:
if not isinstance(node.op, tuple(self._tracks)):
return
# first, we need to provide dummy values for all inputs
# to the node that are not shared variables anyway
givens = {}
missing = set()
for input in node.inputs:
if isinstance(input, pytensor.compile.SharedVariable):
pass
elif hasattr(input.tag, "test_value"):
givens[input] = pytensor.shared(
input.type.filter(input.tag.test_value),
input.name,
shape=input.broadcastable,
borrow=True,
)
else:
missing.add(input)
if missing:
givens.update(self.provide_inputs(node, missing))
missing.difference_update(givens.keys())
# ensure we have data for all input variables that need it
if missing:
if self.verbose > 0:
print( # noqa: T201
f"{self.__class__.__name__} cannot meta-rewrite {node}, "
f"{len(missing)} of {int(node.nin)} input shapes unknown"
)
return
# now we can apply the different rewrites in turn,
# compile the resulting subgraphs and time their execution
if self.verbose > 1:
print( # noqa: T201
f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
)
timings = []
for node_rewriter in self.get_rewrites(node):
outputs = node_rewriter.transform(fgraph, node, *args, **kwargs)
if outputs:
try:
fn = pytensor.function(
[], outputs, givens=givens, on_unused_input="ignore"
)
fn.trust_input = True
timing = min(self.time_call(fn) for _ in range(2))
except MetaNodeRewriterSkip:
continue
except Exception as e:
if self.verbose > 0:
print(f"* {node_rewriter}: exception", e) # noqa: T201
continue
else:
if self.verbose > 1:
print(f"* {node_rewriter}: {timing:.5g} sec") # noqa: T201
timings.append((timing, outputs, node_rewriter))
else:
if self.verbose > 0:
print(f"* {node_rewriter}: not applicable") # noqa: T201
# finally, we choose the fastest one
if timings:
timings.sort()
if self.verbose > 1:
print(f"= {timings[0][2]}") # noqa: T201
return timings[0][1]
return
def provide_inputs(self, node, inputs):
"""Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values.
The `node` argument can be inspected to infer required input shapes.
"""
raise NotImplementedError()
def get_rewrites(self, node):
"""Return the rewrites that apply to `node`.
This uses ``self.track_dict[type(node.op)]`` by default.
"""
return self.track_dict[type(node.op)]
def time_call(self, fn):
start = time.perf_counter()
fn()
return time.perf_counter() - start
class FromFunctionNodeRewriter(NodeRewriter): class FromFunctionNodeRewriter(NodeRewriter):
"""A `NodeRewriter` constructed from a function.""" """A `NodeRewriter` constructed from a function."""
...@@ -1214,9 +1082,6 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1214,9 +1082,6 @@ class SequentialNodeRewriter(NodeRewriter):
reentrant : bool reentrant : bool
Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to
determine if they should ignore new nodes. determine if they should ignore new nodes.
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
to the outputs.
""" """
def __init__( def __init__(
...@@ -1247,9 +1112,6 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1247,9 +1112,6 @@ class SequentialNodeRewriter(NodeRewriter):
self.reentrant = any( self.reentrant = any(
getattr(rewrite, "reentrant", True) for rewrite in rewriters getattr(rewrite, "reentrant", True) for rewrite in rewriters
) )
self.retains_inputs = all(
getattr(rewrite, "retains_inputs", False) for rewrite in rewriters
)
self.apply_all_rewrites = apply_all_rewrites self.apply_all_rewrites = apply_all_rewrites
...@@ -1425,17 +1287,12 @@ class SubstitutionNodeRewriter(NodeRewriter): ...@@ -1425,17 +1287,12 @@ class SubstitutionNodeRewriter(NodeRewriter):
# an SubstitutionNodeRewriter does not apply to the nodes it produces # an SubstitutionNodeRewriter does not apply to the nodes it produces
reentrant = False reentrant = False
# all the inputs of the original node are transferred to the outputs
retains_inputs = True
def __init__(self, op1, op2, transfer_tags=True): def __init__(self, op1, op2, transfer_tags=True):
self.op1 = op1 self.op1 = op1
self.op2 = op2 self.op2 = op2
self.transfer_tags = transfer_tags self.transfer_tags = transfer_tags
def op_key(self):
return self.op1
def tracks(self): def tracks(self):
return [self.op1] return [self.op1]
...@@ -1453,39 +1310,6 @@ class SubstitutionNodeRewriter(NodeRewriter): ...@@ -1453,39 +1310,6 @@ class SubstitutionNodeRewriter(NodeRewriter):
return f"{self.op1} -> {self.op2}" return f"{self.op1} -> {self.op2}"
class RemovalNodeRewriter(NodeRewriter):
"""
Removes all applications of an `Op` by transferring each of its
outputs to the corresponding input.
"""
reentrant = False # no nodes are added at all
def __init__(self, op):
self.op = op
def op_key(self):
return self.op
def tracks(self):
return [self.op]
def transform(self, fgraph, node):
if node.op != self.op:
return False
return node.inputs
def __str__(self):
return f"{self.op}(x) -> x"
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(
f"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}",
file=stream,
)
class PatternNodeRewriter(NodeRewriter): class PatternNodeRewriter(NodeRewriter):
"""Replace all occurrences of an input pattern with an output pattern. """Replace all occurrences of an input pattern with an output pattern.
...@@ -1545,7 +1369,6 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1545,7 +1369,6 @@ class PatternNodeRewriter(NodeRewriter):
in_pattern, in_pattern,
out_pattern, out_pattern,
allow_multiple_clients: bool = False, allow_multiple_clients: bool = False,
skip_identities_fn=None,
name: str | None = None, name: str | None = None,
tracks=(), tracks=(),
get_nodes=None, get_nodes=None,
...@@ -1563,8 +1386,6 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1563,8 +1386,6 @@ class PatternNodeRewriter(NodeRewriter):
allow_multiple_clients allow_multiple_clients
If ``False``, the pattern matching will fail if one of the subpatterns has If ``False``, the pattern matching will fail if one of the subpatterns has
more than one client. more than one client.
skip_identities_fn
TODO
name name
Set the name of this rewriter. Set the name of this rewriter.
tracks tracks
...@@ -1574,15 +1395,15 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1574,15 +1395,15 @@ class PatternNodeRewriter(NodeRewriter):
function that takes the tracked node and returns a list of nodes on function that takes the tracked node and returns a list of nodes on
which we will try this rewrite. which we will try this rewrite.
values_eq_approx values_eq_approx
TODO If specified, this value will be assigned to the ``values_eq_approx``
tag of the output variable. This is used by DebugMode to determine if rewrites are correct.
allow_cast allow_cast
Automatically cast the output of the rewrite whenever new and old types differ Automatically cast the output of the rewrite whenever new and old types differ
Notes Notes
----- -----
`tracks` and `get_nodes` can be used to make this rewrite track a less `tracks` and `get_nodes` can be used to make this rewrite track a less
frequent `Op`, which will prevent the rewrite from being tried as frequent `Op`, which will prevent the rewrite from being tried as often.
often.
""" """
from pytensor.graph.rewriting.unify import convert_strs_to_vars from pytensor.graph.rewriting.unify import convert_strs_to_vars
...@@ -1600,9 +1421,7 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1600,9 +1421,7 @@ class PatternNodeRewriter(NodeRewriter):
raise TypeError( raise TypeError(
"The pattern to search for must start with a specific Op instance." "The pattern to search for must start with a specific Op instance."
) )
self.__doc__ = f"{self.__class__.__doc__}\n\nThis instance does: {self}\n"
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn
if name: if name:
self.__name__ = name self.__name__ = name
self._tracks = tracks self._tracks = tracks
...@@ -1610,9 +1429,6 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1610,9 +1429,6 @@ class PatternNodeRewriter(NodeRewriter):
if tracks != (): if tracks != ():
assert get_nodes assert get_nodes
def op_key(self):
return self.op
def tracks(self): def tracks(self):
if self._tracks != (): if self._tracks != ():
return self._tracks return self._tracks
...@@ -2136,7 +1952,7 @@ def walking_rewriter( ...@@ -2136,7 +1952,7 @@ def walking_rewriter(
else: else:
(node_rewriters,) = node_rewriters (node_rewriters,) = node_rewriters
if not name: if not name:
name = node_rewriters.__name__ name = getattr(node_rewriters, "__name__", None)
ret = WalkingGraphRewriter( ret = WalkingGraphRewriter(
node_rewriters, node_rewriters,
order=order, order=order,
...@@ -2152,52 +1968,6 @@ in2out = partial(walking_rewriter, "in_to_out") ...@@ -2152,52 +1968,6 @@ in2out = partial(walking_rewriter, "in_to_out")
out2in = partial(walking_rewriter, "out_to_in") out2in = partial(walking_rewriter, "out_to_in")
class OpKeyGraphRewriter(NodeProcessingGraphRewriter):
r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
as a list of `Op`\s or a single `Op`), and discovered within a
`FunctionGraph` using the `NodeFinder` `Feature`.
This is similar to the `Op`-based tracking feature used by other rewriters.
"""
def __init__(self, node_rewriter, ignore_newtrees=False, failure_callback=None):
if not hasattr(node_rewriter, "op_key"):
raise TypeError(f"{node_rewriter} must have an `op_key` method.")
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
def apply(self, fgraph):
op = self.node_rewriter.op_key()
if isinstance(op, list | tuple):
q = reduce(list.__iadd__, map(fgraph.get_nodes, op))
else:
q = list(fgraph.get_nodes(op))
def importer(node):
if node is not current_node:
if node.op == op:
q.append(node)
u = self.attach_updater(
fgraph, importer, None, name=getattr(self, "name", None)
)
try:
while q:
node = q.pop()
if node not in fgraph.apply_nodes:
continue
current_node = node
self.process_node(fgraph, node)
finally:
self.detach_updater(fgraph, u)
def add_requirements(self, fgraph):
super().add_requirements(fgraph)
fgraph.attach_feature(NodeFinder())
class ChangeTracker(Feature): class ChangeTracker(Feature):
def __init__(self): def __init__(self):
self.changed = False self.changed = False
...@@ -2785,38 +2555,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2785,38 +2555,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
) )
def _check_chain(r, chain):
"""
WRITEME
"""
chain = list(reversed(chain))
while chain:
elem = chain.pop()
if elem is None:
if r.owner is not None:
return False
elif r.owner is None:
return False
elif isinstance(elem, Op):
if r.owner.op != elem:
return False
else:
try:
if issubclass(elem, Op) and not isinstance(r.owner.op, elem):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls
# _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
return r is not None
def pre_greedy_node_rewriter( def pre_greedy_node_rewriter(
fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable
) -> Variable: ) -> Variable:
......
...@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant ...@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
NodeRewriter, NodeRewriter,
RemovalNodeRewriter,
Rewriter, Rewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
...@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node): ...@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node):
return [alloc(inputs_inner[0], *dims_outer)] return [alloc(inputs_inner[0], *dims_outer)]
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") @register_canonicalize
@node_rewriter(tracks=[tensor_copy])
def remove_tensor_copy(fgraph, node):
return node.inputs
@register_specialize @register_specialize
......
...@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10): ...@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
return np.allclose(x, ref, rtol=rtol, atol=atol) return np.allclose(x, ref, rtol=rtol, atol=atol)
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
if len(not_is_1) == 1:
return not_is_1[0]
def _is_1(expr): def _is_1(expr):
""" """
...@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter( ...@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter(
(neg, (softplus, (neg, "x"))), (neg, (softplus, (neg, "x"))),
allow_multiple_clients=True, allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth1, get_nodes=get_clients_at_depth1,
) )
...@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter( ...@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
(neg, (softplus, "x")), (neg, (softplus, "x")),
allow_multiple_clients=True, allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth2, get_nodes=get_clients_at_depth2,
) )
......
...@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out ...@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker from pytensor.link.vm import VMLinker
from pytensor.printing import debugprint from pytensor.printing import debugprint
...@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error") ...@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error")
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestFunction: class TestFunction:
......
...@@ -8,11 +8,9 @@ from pytensor.graph.op import Op ...@@ -8,11 +8,9 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter, EquilibriumGraphRewriter,
MergeOptimizer, MergeOptimizer,
OpKeyGraphRewriter,
OpToRewriterTracker, OpToRewriterTracker,
PatternNodeRewriter, PatternNodeRewriter,
SequentialNodeRewriter, SequentialNodeRewriter,
SubstitutionNodeRewriter,
WalkingGraphRewriter, WalkingGraphRewriter,
in2out, in2out,
logging, logging,
...@@ -51,33 +49,29 @@ class AssertNoChanges(Feature): ...@@ -51,33 +49,29 @@ class AssertNoChanges(Feature):
raise AssertionError() raise AssertionError()
def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False): def WalkingPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False):
return OpKeyGraphRewriter( return WalkingGraphRewriter(
PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients), PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients),
ignore_newtrees=ign, ignore_newtrees=ign,
) )
def WalkingPatternNodeRewriter(p1, p2, ign=True):
return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestPatternNodeRewriter: class TestPatternNodeRewriter:
def test_replace_output(self): def test_replace_output(self):
# replacing the whole graph # replacing the whole graph
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite( WalkingPatternNodeRewriter(
g (op1, (op2, "1", "2"), "3"), (op4, "3", "2")
) ).rewrite(g)
assert str(g) == "FunctionGraph(Op4(z, y))" assert str(g) == "FunctionGraph(Op4(z, y))"
def test_nested_out_pattern(self): def test_nested_out_pattern(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(x, y) e = op1(x, y)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2")) (op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
).rewrite(g) ).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))" assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
...@@ -86,7 +80,7 @@ class TestPatternNodeRewriter: ...@@ -86,7 +80,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, x), z) # the arguments to op2 are the same e = op1(op2(x, x), z) # the arguments to op2 are the same
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"), (op4, "2", "1"),
).rewrite(g) ).rewrite(g)
...@@ -97,7 +91,7 @@ class TestPatternNodeRewriter: ...@@ -97,7 +91,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) # the arguments to op2 are different e = op1(op2(x, y), z) # the arguments to op2 are different
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern (op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"), (op4, "2", "1"),
).rewrite(g) ).rewrite(g)
...@@ -109,7 +103,7 @@ class TestPatternNodeRewriter: ...@@ -109,7 +103,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g) WalkingPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))" assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_no_recurse(self): def test_no_recurse(self):
...@@ -119,7 +113,9 @@ class TestPatternNodeRewriter: ...@@ -119,7 +113,9 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g) WalkingPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(
g
)
assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))" assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))"
def test_multiple(self): def test_multiple(self):
...@@ -127,7 +123,7 @@ class TestPatternNodeRewriter: ...@@ -127,7 +123,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), op2(x, y), op2(y, z)) e = op1(op2(x, y), op2(x, y), op2(y, z))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g) WalkingPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))" assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_nested_even(self): def test_nested_even(self):
...@@ -136,21 +132,21 @@ class TestPatternNodeRewriter: ...@@ -136,21 +132,21 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(x)))) e = op1(op1(op1(op1(x))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) WalkingPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(x)" assert str(g) == "FunctionGraph(x)"
def test_nested_odd(self): def test_nested_odd(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x))))) e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g) WalkingPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(Op1(x))" assert str(g) == "FunctionGraph(Op1(x))"
def test_expand(self): def test_expand(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(x))) e = op1(op1(op1(x)))
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g) WalkingPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))" assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self): def test_ambiguous(self):
...@@ -169,7 +165,7 @@ class TestPatternNodeRewriter: ...@@ -169,7 +165,7 @@ class TestPatternNodeRewriter:
z = Constant(MyType(), 2, name="z") z = Constant(MyType(), 2, name="z")
e = op1(op1(x, y), y) e = op1(op1(x, y), y)
g = FunctionGraph([y], [e]) g = FunctionGraph([y], [e])
OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) WalkingPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))" assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))"
def test_constraints(self): def test_constraints(self):
...@@ -181,7 +177,7 @@ class TestPatternNodeRewriter: ...@@ -181,7 +177,7 @@ class TestPatternNodeRewriter:
# Only replacing if the input is an instance of Op2 # Only replacing if the input is an instance of Op2
return r.owner.op == op2 return r.owner.op == op2
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op1, {"pattern": "1", "constraint": constraint}), (op3, "1") (op1, {"pattern": "1", "constraint": constraint}), (op3, "1")
).rewrite(g) ).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))" assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
...@@ -190,7 +186,7 @@ class TestPatternNodeRewriter: ...@@ -190,7 +186,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(x, x) e = op1(x, x)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g) WalkingPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(x, x))" assert str(g) == "FunctionGraph(Op3(x, x))"
@pytest.mark.xfail( @pytest.mark.xfail(
...@@ -202,10 +198,10 @@ class TestPatternNodeRewriter: ...@@ -202,10 +198,10 @@ class TestPatternNodeRewriter:
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
def constraint(r): def constraint(r):
# Only replacing if the input is an instance of Op2 # Only replacing if the inputs are not identical
return r.owner.inputs[0] is not r.owner.inputs[1] return r.owner.inputs[0] is not r.owner.inputs[1]
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
{"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y") {"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y")
).rewrite(g) ).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))" assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
...@@ -220,7 +216,7 @@ class TestPatternNodeRewriter: ...@@ -220,7 +216,7 @@ class TestPatternNodeRewriter:
# So the replacement should fail # So the replacement should fail
outputs = [e] outputs = [e]
g = FunctionGraph(inputs, outputs, copy_inputs=False) g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op4, (op1, "x", "y")), (op4, (op1, "x", "y")),
(op3, "x", "y"), (op3, "x", "y"),
).rewrite(g) ).rewrite(g)
...@@ -228,7 +224,7 @@ class TestPatternNodeRewriter: ...@@ -228,7 +224,7 @@ class TestPatternNodeRewriter:
# Now it should be fine # Now it should be fine
g = FunctionGraph(inputs, outputs, copy_inputs=False) g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op4, (op1, "x", "y")), (op4, (op1, "x", "y")),
(op3, "x", "y"), (op3, "x", "y"),
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -237,7 +233,7 @@ class TestPatternNodeRewriter: ...@@ -237,7 +233,7 @@ class TestPatternNodeRewriter:
# The fact that the inputs of the pattern have multiple clients should not matter # The fact that the inputs of the pattern have multiple clients should not matter
g = FunctionGraph(inputs, outputs, copy_inputs=False) g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op3, (op4, "w"), "w"), (op3, (op4, "w"), "w"),
(op3, "w", "w"), (op3, "w", "w"),
allow_multiple_clients=False, allow_multiple_clients=False,
...@@ -252,7 +248,7 @@ class TestPatternNodeRewriter: ...@@ -252,7 +248,7 @@ class TestPatternNodeRewriter:
outputs = [e1, e2] outputs = [e1, e2]
g = FunctionGraph(inputs, outputs, copy_inputs=False) g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op4, (op4, "e")), (op4, (op4, "e")),
"e", "e",
allow_multiple_clients=False, allow_multiple_clients=False,
...@@ -261,7 +257,7 @@ class TestPatternNodeRewriter: ...@@ -261,7 +257,7 @@ class TestPatternNodeRewriter:
outputs = [e1, e3] outputs = [e1, e3]
g = FunctionGraph([x, y, z], outputs, copy_inputs=False) g = FunctionGraph([x, y, z], outputs, copy_inputs=False)
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op4, (op4, "e")), (op4, (op4, "e")),
"e", "e",
allow_multiple_clients=False, allow_multiple_clients=False,
...@@ -269,7 +265,7 @@ class TestPatternNodeRewriter: ...@@ -269,7 +265,7 @@ class TestPatternNodeRewriter:
assert equal_computations(g.outputs, outputs) assert equal_computations(g.outputs, outputs)
g = FunctionGraph(inputs, outputs, copy_inputs=False) g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter( WalkingPatternNodeRewriter(
(op4, (op4, "e")), (op4, (op4, "e")),
"e", "e",
allow_multiple_clients=True, allow_multiple_clients=True,
...@@ -281,33 +277,13 @@ class TestPatternNodeRewriter: ...@@ -281,33 +277,13 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op_y(x, y), z) e = op1(op_y(x, y), z)
g = FunctionGraph([x, y, z], [e]) g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite( WalkingPatternNodeRewriter(
g (op1, (op_z, "1", "2"), "3"), (op4, "3", "2")
) ).rewrite(g)
str_g = str(g) str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))" assert str_g == "FunctionGraph(Op4(z, y))"
def KeyedSubstitutionNodeRewriter(op1, op2):
return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2))
class TestSubstitutionNodeRewriter:
def test_straightforward(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
class NoInputOp(Op): class NoInputOp(Op):
__props__ = ("param",) __props__ = ("param",)
......
...@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph ...@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
OpKeyGraphRewriter,
PatternNodeRewriter, PatternNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
WalkingGraphRewriter, WalkingGraphRewriter,
...@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast ...@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast
def OpKeyPatternNodeRewriter(p1, p2, ign=True): def OpKeyPatternNodeRewriter(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoSubstitutionNodeRewriter( def TopoSubstitutionNodeRewriter(
......
...@@ -2,92 +2,12 @@ import pytest ...@@ -2,92 +2,12 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import rewrite_graph from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.basic import equal_computations
from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate from pytensor.graph.features import Feature, FullHistory, ReplaceValidate
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from tests.graph.utils import MyVariable, op1 from tests.graph.utils import MyVariable, op1
class TestNodeFinder:
def test_straightforward(self):
class MyType(Type):
def __init__(self, name):
self.name = name
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
return isinstance(other, MyType)
class MyOp(Op):
__props__ = ("nin", "name")
def __init__(self, nin, name):
self.nin = nin
self.name = name
def make_node(self, *inputs):
def as_variable(x):
assert isinstance(x, Variable)
return x
assert len(inputs) == self.nin
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def perform(self, *args, **kwargs):
raise NotImplementedError()
sigmoid = MyOp(1, "Sigmoid")
add = MyOp(2, "Add")
dot = MyOp(2, "Dot")
def MyVariable(name):
return Variable(MyType(name), None, None)
def inputs():
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = FunctionGraph([x, y, z], [e], clone=False)
g.attach_feature(NodeFinder())
assert hasattr(g, "get_nodes")
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if len(list(g.get_nodes(type))) != num:
raise Exception(f"Expected: {num} times {type}")
new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot)
assert new_e0.owner not in g.get_nodes(add)
g.replace(e0, new_e0)
assert e0.owner not in g.get_nodes(dot)
assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if len(list(g.get_nodes(type))) != num:
raise Exception(f"Expected: {num} times {type}")
class TestReplaceValidate: class TestReplaceValidate:
def test_verbose(self, capsys): def test_verbose(self, capsys):
var1 = MyVariable("var1") var1 = MyVariable("var1")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论