提交 2d46d60e authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Remove unused rewrites and functionality

上级 40ccab1a
...@@ -134,7 +134,7 @@ computation graph. ...@@ -134,7 +134,7 @@ computation graph.
In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`, In a nutshell, :class:`ReplaceValidate` grants access to :meth:`fgraph.replace_validate`,
and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with and :meth:`fgraph.replace_validate` allows us to replace a :class:`Variable` with
another while respecting certain validation constraints. As an another while respecting certain validation constraints. As an
exercise, try to rewrite :class:`Simplify` using :class:`NodeFinder`. (Hint: you exercise, try to rewrite :class:`Simplify` using :class:`WalkingGraphRewriter`. (Hint: you
want to use the method it publishes instead of the call to toposort) want to use the method it publishes instead of the call to toposort)
Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by Then, in :meth:`GraphRewriter.apply` we do the actual job of simplification. We start by
......
...@@ -26,7 +26,3 @@ Guide ...@@ -26,7 +26,3 @@ Guide
.. class:: ReplaceValidate(History, Validator) .. class:: ReplaceValidate(History, Validator)
.. method:: replace_validate(fgraph, var, new_var, reason=None) .. method:: replace_validate(fgraph, var, new_var, reason=None)
.. class:: NodeFinder(Bookkeeper)
.. class:: PrintListener(object)
...@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator): ...@@ -827,100 +827,6 @@ class ReplaceValidate(History, Validator):
raise InconsistencyError("Trying to reintroduce a removed node") raise InconsistencyError("Trying to reintroduce a removed node")
class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}
def on_attach(self, fgraph):
if hasattr(fgraph, "get_nodes"):
raise AlreadyThere("NodeFinder is already present")
if self.fgraph is not None and self.fgraph != fgraph:
raise Exception("A NodeFinder instance can only serve one FunctionGraph.")
self.fgraph = fgraph
fgraph.get_nodes = partial(self.query, fgraph)
Bookkeeper.on_attach(self, fgraph)
def clone(self):
return type(self)()
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.fgraph is not fgraph:
raise Exception(
"This NodeFinder instance was not attached to the provided fgraph."
)
self.fgraph = None
del fgraph.get_nodes
Bookkeeper.on_detach(self, fgraph)
def on_import(self, fgraph, node, reason):
try:
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception as e:
print("OFFENDING node", type(node), type(node.op), file=sys.stderr) # noqa: T201
try:
print("OFFENDING node hash", hash(node.op), file=sys.stderr) # noqa: T201
except Exception:
print("OFFENDING node not hashable", file=sys.stderr) # noqa: T201
raise e
def on_prune(self, fgraph, node, reason):
try:
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self.d[node.op]
def query(self, fgraph, op):
try:
all = self.d.get(op, [])
except TypeError:
raise TypeError(
f"{op} in unhashable and cannot be queried by the optimizer"
)
all = list(all)
return all
class PrintListener(Feature):
def __init__(self, active=True):
self.active = active
def on_attach(self, fgraph):
if self.active:
print("-- attaching to: ", fgraph) # noqa: T201
def on_detach(self, fgraph):
"""
Should remove any dynamically added functionality
that it installed into the function_graph
"""
if self.active:
print("-- detaching from: ", fgraph) # noqa: T201
def on_import(self, fgraph, node, reason):
if self.active:
print(f"-- importing: {node}, reason: {reason}") # noqa: T201
def on_prune(self, fgraph, node, reason):
if self.active:
print(f"-- pruning: {node}, reason: {reason}") # noqa: T201
def on_change_input(self, fgraph, node, i, r, new_r, reason=None):
if self.active:
print(f"-- changing ({node}.inputs[{i}]) from {r} to {new_r}") # noqa: T201
class PreserveVariableAttributes(Feature): class PreserveVariableAttributes(Feature):
""" """
This preserve some variables attributes and tag during optimization. This preserve some variables attributes and tag during optimization.
......
...@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant ...@@ -34,7 +34,6 @@ from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
NodeRewriter, NodeRewriter,
RemovalNodeRewriter,
Rewriter, Rewriter,
copy_stack_trace, copy_stack_trace,
in2out, in2out,
...@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node): ...@@ -1224,7 +1223,10 @@ def local_merge_alloc(fgraph, node):
return [alloc(inputs_inner[0], *dims_outer)] return [alloc(inputs_inner[0], *dims_outer)]
register_canonicalize(RemovalNodeRewriter(tensor_copy), name="remove_tensor_copy") @register_canonicalize
@node_rewriter(tracks=[tensor_copy])
def remove_tensor_copy(fgraph, node):
return node.inputs
@register_specialize @register_specialize
......
...@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10): ...@@ -3162,13 +3162,6 @@ def isclose(x, ref, rtol=0, atol=0, num_ulps=10):
return np.allclose(x, ref, rtol=rtol, atol=atol) return np.allclose(x, ref, rtol=rtol, atol=atol)
def _skip_mul_1(r):
if r.owner and r.owner.op == mul:
not_is_1 = [i for i in r.owner.inputs if not _is_1(i)]
if len(not_is_1) == 1:
return not_is_1[0]
def _is_1(expr): def _is_1(expr):
""" """
...@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter( ...@@ -3190,7 +3183,6 @@ logsigm_to_softplus = PatternNodeRewriter(
(neg, (softplus, (neg, "x"))), (neg, (softplus, (neg, "x"))),
allow_multiple_clients=True, allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth1, get_nodes=get_clients_at_depth1,
) )
...@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter( ...@@ -3199,7 +3191,6 @@ log1msigm_to_softplus = PatternNodeRewriter(
(neg, (softplus, "x")), (neg, (softplus, "x")),
allow_multiple_clients=True, allow_multiple_clients=True,
values_eq_approx=values_eq_approx_remove_inf, values_eq_approx=values_eq_approx_remove_inf,
skip_identities_fn=_skip_mul_1,
tracks=[sigmoid], tracks=[sigmoid],
get_nodes=get_clients_at_depth2, get_nodes=get_clients_at_depth2,
) )
......
...@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out ...@@ -13,7 +13,7 @@ from pytensor.compile.io import In, Out
from pytensor.compile.mode import Mode, get_default_mode from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import OpKeyGraphRewriter, PatternNodeRewriter from pytensor.graph.rewriting.basic import PatternNodeRewriter, WalkingGraphRewriter
from pytensor.graph.utils import MissingInputError from pytensor.graph.utils import MissingInputError
from pytensor.link.vm import VMLinker from pytensor.link.vm import VMLinker
from pytensor.printing import debugprint from pytensor.printing import debugprint
...@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error") ...@@ -39,7 +39,7 @@ pytestmark = pytest.mark.filterwarnings("error")
def PatternOptimizer(p1, p2, ign=True): def PatternOptimizer(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
class TestFunction: class TestFunction:
......
...@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph ...@@ -10,7 +10,6 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import ( from pytensor.graph.rewriting.basic import (
NodeProcessingGraphRewriter, NodeProcessingGraphRewriter,
OpKeyGraphRewriter,
PatternNodeRewriter, PatternNodeRewriter,
SubstitutionNodeRewriter, SubstitutionNodeRewriter,
WalkingGraphRewriter, WalkingGraphRewriter,
...@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast ...@@ -21,7 +20,7 @@ from tests.unittest_tools import assertFailure_fast
def OpKeyPatternNodeRewriter(p1, p2, ign=True): def OpKeyPatternNodeRewriter(p1, p2, ign=True):
return OpKeyGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign) return WalkingGraphRewriter(PatternNodeRewriter(p1, p2), ignore_newtrees=ign)
def TopoSubstitutionNodeRewriter( def TopoSubstitutionNodeRewriter(
......
...@@ -2,92 +2,12 @@ import pytest ...@@ -2,92 +2,12 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor.graph import rewrite_graph from pytensor.graph import rewrite_graph
from pytensor.graph.basic import Apply, Variable, equal_computations from pytensor.graph.basic import equal_computations
from pytensor.graph.features import Feature, FullHistory, NodeFinder, ReplaceValidate from pytensor.graph.features import Feature, FullHistory, ReplaceValidate
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from tests.graph.utils import MyVariable, op1 from tests.graph.utils import MyVariable, op1
class TestNodeFinder:
def test_straightforward(self):
class MyType(Type):
def __init__(self, name):
self.name = name
def filter(self, *args, **kwargs):
raise NotImplementedError()
def __str__(self):
return self.name
def __repr__(self):
return self.name
def __eq__(self, other):
return isinstance(other, MyType)
class MyOp(Op):
__props__ = ("nin", "name")
def __init__(self, nin, name):
self.nin = nin
self.name = name
def make_node(self, *inputs):
def as_variable(x):
assert isinstance(x, Variable)
return x
assert len(inputs) == self.nin
inputs = list(map(as_variable, inputs))
for input in inputs:
if not isinstance(input.type, MyType):
raise Exception("Error 1")
outputs = [MyType(self.name + "_R")()]
return Apply(self, inputs, outputs)
def __str__(self):
return self.name
def perform(self, *args, **kwargs):
raise NotImplementedError()
sigmoid = MyOp(1, "Sigmoid")
add = MyOp(2, "Add")
dot = MyOp(2, "Dot")
def MyVariable(name):
return Variable(MyType(name), None, None)
def inputs():
x = MyVariable("x")
y = MyVariable("y")
z = MyVariable("z")
return x, y, z
x, y, z = inputs()
e0 = dot(y, z)
e = add(add(sigmoid(x), sigmoid(sigmoid(z))), dot(add(x, y), e0))
g = FunctionGraph([x, y, z], [e], clone=False)
g.attach_feature(NodeFinder())
assert hasattr(g, "get_nodes")
for type, num in ((add, 3), (sigmoid, 3), (dot, 2)):
if len(list(g.get_nodes(type))) != num:
raise Exception(f"Expected: {num} times {type}")
new_e0 = add(y, z)
assert e0.owner in g.get_nodes(dot)
assert new_e0.owner not in g.get_nodes(add)
g.replace(e0, new_e0)
assert e0.owner not in g.get_nodes(dot)
assert new_e0.owner in g.get_nodes(add)
for type, num in ((add, 4), (sigmoid, 3), (dot, 1)):
if len(list(g.get_nodes(type))) != num:
raise Exception(f"Expected: {num} times {type}")
class TestReplaceValidate: class TestReplaceValidate:
def test_verbose(self, capsys): def test_verbose(self, capsys):
var1 = MyVariable("var1") var1 = MyVariable("var1")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论