提交 2d46d60e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove unused rewrites and functionality

上级 40ccab1a
......@@ -134,7 +134,7 @@ computation graph.
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you
exercise, try to rewrite :class:`Simplify` using :class:`WalkingGraphRewriter`. (Hint: you
want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
......
......@@ -26,7 +26,3 @@ Guide
.. class:: ReplaceValidate(History, Validator)
.. method:: replace_validate(fgraph, var, new_var, reason=None)
.. class:: NodeFinder(Bookkeeper)
.. class:: PrintListener(object)
......@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator):
raise InconsistencyError("Trying to reintroduce a removed node")
class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}
def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present")
if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph)
def clone(self):
return type(self)()
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.fgraph is not fgraph:
raise Exception(
"This NodeFinder instance was not attached to the provided fgraph."
)
self.fgraph = None
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node, reason):
try:
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception as e:
print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201
try:
print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201
except Exception:
print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201
raise e
def on_prune(self, fgraph, node, reason):
try:
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self.d[node.op]
def query(self, fgraph, op):
try:
all = self.d.get(op, [])
except TypeError:
raise TypeError(
f"{op} in unhashable and cannot be queried by the optimizer"
)
all = list(all)
return all
class PrintListener(Feature):
def __init__(self, active=True):
self.active = active
def on_attach(self, fgraph):
if self.active:
print("-- attaching to: ", fgraph) # noqa: T201
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.active:
print("-- detaching from: ", fgraph) # noqa: T201
def on_import(self, fgraph, node, reason):
if self.active:
print(f"-- importing: {node}, reason: {reason}") # noqa: T201
def on_prune(self, fgraph, node, reason):
if self.active:
print(f"-- pruning: {node}, reason: {reason}") # noqa: T201
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201
class PreserveVariableAttributes(Feature):
"""
This preserve some variables attributes and tag during optimization.
......
......@@ -11,8 +11,7 @@ import traceback
import warnings
from collections import Counter, UserList, defaultdict, deque
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore
from functools import _compose_mro, partial # type: ignore
from itertools import chain
from typing import TYPE_CHECKING, Literal
......@@ -28,7 +27,7 @@ from pytensor.graph.basic import (
io_toposort,
vars_between,
)
from pytensor.graph.features import AlreadyThere, Feature, NodeFinder
from pytensor.graph.features import AlreadyThere, Feature
from pytensor.graph.fg import FunctionGraph, Output
from pytensor.graph.op import Op
from pytensor.graph.utils import AssocList, InconsistencyError
......@@ -60,14 +59,6 @@ FailureCallbackType = Callable[
]
class MetaNodeRewriterSkip(AssertionError):
"""This is an `AssertionError`, but instead of having the
`MetaNodeRewriter` print the error, it just skip that
compilation.
"""
class Rewriter(abc.ABC):
"""Abstract base class for graph/term rewriters."""
......@@ -942,129 +933,6 @@ def pre_constant_merge(fgraph, variables):
return [recursive_merge(v) for v in variables]
class MetaNodeRewriter(NodeRewriter):
r"""
Base class for meta-rewriters that try a set of `NodeRewriter`\s
to replace a node and choose the one that executes the fastest.
If the error `MetaNodeRewriterSkip` is raised during
compilation, we will skip that function compilation and not print
the error.
"""
def __init__(self):
self.verbose = config.metaopt__verbose
self.track_dict = defaultdict(list)
self.tag_dict = defaultdict(list)
self._tracks = []
self.rewriters = []
def register(self, rewriter: NodeRewriter, tag_list: IterableType[str]):
self.rewriters.append(rewriter)
tracks = rewriter.tracks()
if tracks:
self._tracks.extend(tracks)
for c in tracks:
self.track_dict[c].append(rewriter)
for tag in tag_list:
self.tag_dict[tag].append(rewriter)
def tracks(self):
return self._tracks
def transform(self, fgraph, node, *args, **kwargs):
# safety check: depending on registration, tracks may have been ignored
if self._tracks is not None:
if not isinstance(node.op, tuple(self._tracks)):
return
# first, we need to provide dummy values for all inputs
# to the node that are not shared variables anyway
givens = {}
missing = set()
for input in node.inputs:
if isinstance(input, pytensor.compile.SharedVariable):
pass
elif hasattr(input.tag, "test_value"):
givens[input] = pytensor.shared(
input.type.filter(input.tag.test_value),
input.name,
shape=input.broadcastable,
borrow=True,
)
else:
missing.add(input)
if missing:
givens.update(self.provide_inputs(node, missing))
missing.difference_update(givens.keys())
# ensure we have data for all input variables that need it
if missing:
if self.verbose > 0:
print( # noqa: T201
f"{self.__class__.__name__} cannot meta-rewrite {node}, "
f"{len(missing)} of {int(node.nin)} input shapes unknown"
)
return
# now we can apply the different rewrites in turn,
# compile the resulting subgraphs and time their execution
if self.verbose > 1:
print( # noqa: T201
f"{self.__class__.__name__} meta-rewriting {node} ({len(self.get_rewrites(node))} choices):"
)
timings = []
for node_rewriter in self.get_rewrites(node):
outputs = node_rewriter.transform(fgraph, node, *args, **kwargs)
if outputs:
try:
fn = pytensor.function(
[], outputs, givens=givens, on_unused_input="ignore"
)
fn.trust_input = True
timing = min(self.time_call(fn) for _ in range(2))
except MetaNodeRewriterSkip:
continue
except Exception as e:
if self.verbose > 0:
print(f"* {node_rewriter}: exception", e) # noqa: T201
continue
else:
if self.verbose > 1:
print(f"* {node_rewriter}: {timing:.5g} sec") # noqa: T201
timings.append((timing, outputs, node_rewriter))
else:
if self.verbose > 0:
print(f"* {node_rewriter}: not applicable") # noqa: T201
# finally, we choose the fastest one
if timings:
timings.sort()
if self.verbose > 1:
print(f"= {timings[0][2]}") # noqa: T201
return timings[0][1]
return
def provide_inputs(self, node, inputs):
"""Return a dictionary mapping some `inputs` to `SharedVariable` instances of with dummy values.
The `node` argument can be inspected to infer required input shapes.
"""
raise NotImplementedError()
def get_rewrites(self, node):
"""Return the rewrites that apply to `node`.
This uses ``self.track_dict[type(node.op)]`` by default.
"""
return self.track_dict[type(node.op)]
def time_call(self, fn):
start = time.perf_counter()
fn()
return time.perf_counter() - start
class FromFunctionNodeRewriter(NodeRewriter):
"""A `NodeRewriter` constructed from a function."""
......@@ -1214,9 +1082,6 @@ class SequentialNodeRewriter(NodeRewriter):
reentrant : bool
Some global rewriters, like `NodeProcessingGraphRewriter`, use this value to
determine if they should ignore new nodes.
retains_inputs : bool
States whether or not the inputs of a transformed node are transferred
to the outputs.
"""
def __init__(
......@@ -1247,9 +1112,6 @@ class SequentialNodeRewriter(NodeRewriter):
self.reentrant = any(
getattr(rewrite, "reentrant", True) for rewrite in rewriters
)
self.retains_inputs = all(
getattr(rewrite, "retains_inputs", False) for rewrite in rewriters
)
self.apply_all_rewrites = apply_all_rewrites
......@@ -1425,17 +1287,12 @@ class SubstitutionNodeRewriter(NodeRewriter):
# an SubstitutionNodeRewriter does not apply to the nodes it produces
reentrant = False
# all the inputs of the original node are transferred to the outputs
retains_inputs = True
def __init__(self, op1, op2, transfer_tags=True):
self.op1 = op1
self.op2 = op2
self.transfer_tags = transfer_tags
def op_key(self):
return self.op1
def tracks(self):
return [self.op1]
......@@ -1453,39 +1310,6 @@ class SubstitutionNodeRewriter(NodeRewriter):
return f"{self.op1} -> {self.op2}"
class RemovalNodeRewriter(NodeRewriter):
"""
Removes all applications of an `Op` by transferring each of its
outputs to the corresponding input.
"""
reentrant = False # no nodes are added at all
def __init__(self, op):
self.op = op
def op_key(self):
return self.op
def tracks(self):
return [self.op]
def transform(self, fgraph, node):
if node.op != self.op:
return False
return node.inputs
def __str__(self):
return f"{self.op}(x) -> x"
def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print(
f"{' ' * level}{self.__class__.__name__}(self.op) id={id(self)}",
file=stream,
)
class PatternNodeRewriter(NodeRewriter):
"""Replace all occurrences of an input pattern with an output pattern.
......@@ -1545,7 +1369,6 @@ class PatternNodeRewriter(NodeRewriter):
in_pattern,
out_pattern,
allow_multiple_clients: bool = False,
skip_identities_fn=None,
name: str | None = None,
tracks=(),
get_nodes=None,
......@@ -1563,8 +1386,6 @@ class PatternNodeRewriter(NodeRewriter):
allow_multiple_clients
If ``False``, the pattern matching will fail if one of the subpatterns has
more than one client.
skip_identities_fn
TODO
name
Set the name of this rewriter.
tracks
......@@ -1574,15 +1395,15 @@ class PatternNodeRewriter(NodeRewriter):
function that takes the tracked node and returns a list of nodes on
which we will try this rewrite.
values_eq_approx
TODO
If specified, this value will be assigned to the ``values_eq_approx``
tag of the output variable. This is used by DebugMode to determine if rewrites are correct.
allow_cast
Automatically cast the output of the rewrite whenever new and old types differ
Notes
-----
`tracks` and `get_nodes` can be used to make this rewrite track a less
frequent `Op`, which will prevent the rewrite from being tried as
often.
frequent `Op`, which will prevent the rewrite from being tried as often.
"""
from pytensor.graph.rewriting.unify import convert_strs_to_vars
......@@ -1600,9 +1421,7 @@ class PatternNodeRewriter(NodeRewriter):
raise TypeError(
"The pattern to search for must start with a specific Op instance."
)
self.__doc__ = f"{self.__class__.__doc__}\n\nThis instance does: {self}\n"
self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn
if name:
self.__name__ = name
self._tracks = tracks
......@@ -1610,9 +1429,6 @@ class PatternNodeRewriter(NodeRewriter):
if tracks != ():
assert get_nodes
def op_key(self):
return self.op
def tracks(self):
if self._tracks != ():
return self._tracks
......@@ -2136,7 +1952,7 @@ def walking_rewriter(
else:
(node_rewriters,) = node_rewriters
if not name:
name = node_rewriters.__name__
name = getattr(node_rewriters, "__name__", None)
ret = WalkingGraphRewriter(
node_rewriters,
order=order,
......@@ -2152,52 +1968,6 @@ in2out = partial(walking_rewriter, "in_to_out")
out2in = partial(walking_rewriter, "out_to_in")
class OpKeyGraphRewriter(NodeProcessingGraphRewriter):
r"""A rewriter that applies a `NodeRewriter` to specific `Op`\s.
The `Op`\s are provided by a :meth:`NodeRewriter.op_key` method (either
as a list of `Op`\s or a single `Op`), and discovered within a
`FunctionGraph` using the `NodeFinder` `Feature`.
This is similar to the `Op`-based tracking feature used by other rewriters.
"""
def __init__(self, node_rewriter, ignore_newtrees=False, failure_callback=None):
if not hasattr(node_rewriter, "op_key"):
raise TypeError(f"{node_rewriter} must have an `op_key` method.")
super().__init__(node_rewriter, ignore_newtrees, failure_callback)
def apply(self, fgraph):
op = self.node_rewriter.op_key()
if isinstance(op, list | tuple):
q = reduce(list.__iadd__, map(fgraph.get_nodes, op))
else:
q = list(fgraph.get_nodes(op))
def importer(node):
if node is not current_node:
if node.op == op:
q.append(node)
u = self.attach_updater(
fgraph, importer, None, name=getattr(self, "name", None)
)
try:
while q:
node = q.pop()
if node not in fgraph.apply_nodes:
continue
current_node = node
self.process_node(fgraph, node)
finally:
self.detach_updater(fgraph, u)
def add_requirements(self, fgraph):
super().add_requirements(fgraph)
fgraph.attach_feature(NodeFinder())
class ChangeTracker(Feature):
def __init__(self):
self.changed = False
......@@ -2785,38 +2555,6 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
)
def _check_chain(r, chain):
"""
WRITEME
"""
chain = list(reversed(chain))
while chain:
elem = chain.pop()
if elem is None:
if r.owner is not None:
return False
elif r.owner is None:
return False
elif isinstance(elem, Op):
if r.owner.op != elem:
return False
else:
try:
if issubclass(elem, Op) and not isinstance(r.owner.op, elem):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls
# _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
return r is not None
def pre_greedy_node_rewriter(
fgraph: FunctionGraph, rewrites: Sequence[NodeRewriter], out: Variable
) -> Variable:
......
......@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
NodeRewriter,
RemovalNodeRewriter,
Rewriter,
copy_stack_trace,
in2out,
......@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node):
return [alloc(inputs_inner[0], *dims_outer)]
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy")
@register_canonicalize
@node_rewriter(tracks=[tensor_copy])
def remove_tensor_copy(fgraph, node):
return node.inputs
@register_specialize
......
......@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
return np.allclose(x, ref, rtol=rtol, atol=atol)
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
if len(not_is_1) == 1:
return not_is_1[0]
def _is_1(expr):
"""
......@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter(
(neg, (softplus, (neg, "x"))),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
tracks=[sigmoid],
get_nodes=get_clients_at_depth1,
)
......@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
(neg, (softplus, "x")),
allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
tracks=[sigmoid],
get_nodes=get_clients_at_depth2,
)
......
......@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter
from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker
from pytensor.printing import debugprint
......@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error")
def PatternOptimizer(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestFunction:
......
......@@ -8,11 +8,9 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import (
EquilibriumGraphRewriter,
MergeOptimizer,
OpKeyGraphRewriter,
OpToRewriterTracker,
PatternNodeRewriter,
SequentialNodeRewriter,
SubstitutionNodeRewriter,
WalkingGraphRewriter,
in2out,
logging,
......@@ -51,33 +49,29 @@ class AssertNoChanges(Feature):
raise AssertionError()
def OpKeyPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False):
return OpKeyGraphRewriter(
def WalkingPatternNodeRewriter(p1, p2, allow_multiple_clients=False, ign=False):
return WalkingGraphRewriter(
PatternNodeRewriter(p1, p2, allow_multiple_clients=allow_multiple_clients),
ignore_newtrees=ign,
)
def WalkingPatternNodeRewriter(p1, p2, ign=True):
return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestPatternNodeRewriter:
def test_replace_output(self):
# replacing the whole graph
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op2, "1", "2"), "3"), (op4, "3", "2")).rewrite(
g
)
WalkingPatternNodeRewriter(
(op1, (op2, "1", "2"), "3"), (op4, "3", "2")
).rewrite(g)
assert str(g) == "FunctionGraph(Op4(z, y))"
def test_nested_out_pattern(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(x, y)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op1, "1", "2"), (op4, (op1, "1"), (op2, "2"), (op3, "1", "2"))
).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op1(x), Op2(y), Op3(x, y)))"
......@@ -86,7 +80,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, x), z) # the arguments to op2 are the same
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"),
).rewrite(g)
......@@ -97,7 +91,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z) # the arguments to op2 are different
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op1, (op2, "1", "1"), "2"), # they are the same in the pattern
(op4, "2", "1"),
).rewrite(g)
......@@ -109,7 +103,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g)
WalkingPatternNodeRewriter((op2, "1", "2"), (op1, "2", "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op1(y, x), z))"
def test_no_recurse(self):
......@@ -119,7 +113,9 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), z)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(g)
WalkingPatternNodeRewriter((op2, "1", "2"), (op2, "2", "1"), ign=True).rewrite(
g
)
assert str(g) == "FunctionGraph(Op1(Op2(y, x), z))"
def test_multiple(self):
......@@ -127,7 +123,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x, y), op2(x, y), op2(y, z))
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g)
WalkingPatternNodeRewriter((op2, "1", "2"), (op4, "1")).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op4(x), Op4(x), Op4(y)))"
def test_nested_even(self):
......@@ -136,21 +132,21 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(x))))
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g)
WalkingPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(x)"
def test_nested_odd(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g)
WalkingPatternNodeRewriter((op1, (op1, "1")), "1").rewrite(g)
assert str(g) == "FunctionGraph(Op1(x))"
def test_expand(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(x)))
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g)
WalkingPatternNodeRewriter((op1, "1"), (op2, (op1, "1")), ign=True).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(Op2(Op1(Op2(Op1(x)))))))"
def test_ambiguous(self):
......@@ -169,7 +165,7 @@ class TestPatternNodeRewriter:
z = Constant(MyType(), 2, name="z")
e = op1(op1(x, y), y)
g = FunctionGraph([y], [e])
OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g)
WalkingPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))"
def test_constraints(self):
......@@ -181,7 +177,7 @@ class TestPatternNodeRewriter:
# Only replacing if the input is an instance of Op2
return r.owner.op == op2
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op1, {"pattern": "1", "constraint": constraint}), (op3, "1")
).rewrite(g)
assert str(g) == "FunctionGraph(Op4(Op3(Op2(x, y)), Op1(Op1(x, y))))"
......@@ -190,7 +186,7 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(x, x)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g)
WalkingPatternNodeRewriter((op1, "x", "y"), (op3, "x", "y")).rewrite(g)
assert str(g) == "FunctionGraph(Op3(x, x))"
@pytest.mark.xfail(
......@@ -202,10 +198,10 @@ class TestPatternNodeRewriter:
g = FunctionGraph([x, y, z], [e])
def constraint(r):
# Only replacing if the input is an instance of Op2
# Only replacing if the inputs are not identical
return r.owner.inputs[0] is not r.owner.inputs[1]
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
{"pattern": (op1, "x", "y"), "constraint": constraint}, (op3, "x", "y")
).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op1(x, x), Op3(x, y)))"
......@@ -220,7 +216,7 @@ class TestPatternNodeRewriter:
# So the replacement should fail
outputs = [e]
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op4, (op1, "x", "y")),
(op3, "x", "y"),
).rewrite(g)
......@@ -228,7 +224,7 @@ class TestPatternNodeRewriter:
# Now it should be fine
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op4, (op1, "x", "y")),
(op3, "x", "y"),
allow_multiple_clients=True,
......@@ -237,7 +233,7 @@ class TestPatternNodeRewriter:
# The fact that the inputs of the pattern have multiple clients should not matter
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op3, (op4, "w"), "w"),
(op3, "w", "w"),
allow_multiple_clients=False,
......@@ -252,7 +248,7 @@ class TestPatternNodeRewriter:
outputs = [e1, e2]
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=False,
......@@ -261,7 +257,7 @@ class TestPatternNodeRewriter:
outputs = [e1, e3]
g = FunctionGraph([x, y, z], outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=False,
......@@ -269,7 +265,7 @@ class TestPatternNodeRewriter:
assert equal_computations(g.outputs, outputs)
g = FunctionGraph(inputs, outputs, copy_inputs=False)
OpKeyPatternNodeRewriter(
WalkingPatternNodeRewriter(
(op4, (op4, "e")),
"e",
allow_multiple_clients=True,
......@@ -281,33 +277,13 @@ class TestPatternNodeRewriter:
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op_y(x, y), z)
g = FunctionGraph([x, y, z], [e])
OpKeyPatternNodeRewriter((op1, (op_z, "1", "2"), "3"), (op4, "3", "2")).rewrite(
g
)
WalkingPatternNodeRewriter(
(op1, (op_z, "1", "2"), "3"), (op4, "3", "2")
).rewrite(g)
str_g = str(g)
assert str_g == "FunctionGraph(Op4(z, y))"
def KeyedSubstitutionNodeRewriter(op1, op2):
return OpKeyGraphRewriter(SubstitutionNodeRewriter(op1, op2))
class TestSubstitutionNodeRewriter:
def test_straightforward(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op1(op1(op1(op1(x)))))
g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op1, op2).rewrite(g)
assert str(g) == "FunctionGraph(Op2(Op2(Op2(Op2(Op2(x))))))"
def test_straightforward_2(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
e = op1(op2(x), op3(y), op4(z))
g = FunctionGraph([x, y, z], [e])
KeyedSubstitutionNodeRewriter(op3, op4).rewrite(g)
assert str(g) == "FunctionGraph(Op1(Op2(x), Op4(y), Op4(z)))"
class NoInputOp(Op):
__props__ = ("param",)
......
......@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter,
OpKeyGraphRewriter,
PatternNodeRewriter,
SubstitutionNodeRewriter,
WalkingGraphRewriter,
......@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast
def OpKeyPatternNodeRewriter(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoSubstitutionNodeRewriter(
......
......@@ -2,92 +2,12 @@ import pytest
import pytensor.tensor as pt
from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Variable, equal_computations
from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate
from pytensor.graph.basic import equal_computations
from pytensor.graph.features import Feature, FullHistory, ReplaceValidate
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from tests.graph.utils import MyVariable, op1
class TestNodeFinder:
def test_straightforward(self):
class MyType(Type):
def __init__(self, name):
self.name = name
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
return isinstance(other, MyType)
class MyOp(Op):
__props__ = ("nin", "name")
def __init__(self, nin, name):
self.nin = nin
self.name = name
def make_node(self, *inputs):
def as_variable(x):
assert isinstance(x, Variable)
return x
assert len(inputs) == self.nin
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def perform(self, *args, **kwargs):
raise NotImplementedError()
sigmoid = MyOp(1, "Sigmoid")
add = MyOp(2, "Add")
dot = MyOp(2, "Dot")
def MyVariable(name):
return Variable(MyType(name), None, None)
def inputs():
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = FunctionGraph([x, y, z], [e], clone=False)
g.attach_feature(NodeFinder())
assert hasattr(g, "get_nodes")
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if len(list(g.get_nodes(type))) != num:
raise Exception(f"Expected: {num} times {type}")
new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot)
assert new_e0.owner not in g.get_nodes(add)
g.replace(e0, new_e0)
assert e0.owner not in g.get_nodes(dot)
assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if len(list(g.get_nodes(type))) != num:
raise Exception(f"Expected: {num} times {type}")
class TestReplaceValidate:
def test_verbose(self, capsys):
var1 = MyVariable("var1")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论