提交 d475421b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make local_optimizer automatically validate using tracks

上级 1b630dc6
...@@ -1187,10 +1187,19 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -1187,10 +1187,19 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def __init__(self, fn, tracks=None, requirements=()): def __init__(self, fn, tracks=None, requirements=()):
self.fn = fn self.fn = fn
self._tracks = tracks self._tracks = tracks
self._tracked_types = (
tuple(t for t in tracks if isinstance(t, type)) if tracks else ()
)
self.requirements = requirements self.requirements = requirements
def transform(self, *args, **kwargs): def transform(self, fgraph, node):
return self.fn(*args, **kwargs) if self._tracks:
if not (
node.op in self._tracks or isinstance(node.op, self._tracked_types)
):
return False
return self.fn(fgraph, node)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
for req in self.requirements: for req in self.requirements:
...@@ -1218,13 +1227,15 @@ def local_optimizer( ...@@ -1218,13 +1227,15 @@ def local_optimizer(
Parameters Parameters
---------- ----------
tracks : tracks
The `Op` types or instances to which this optimization applies. The `Op` types or instances to which this optimization applies.
inplace : Use ``None`` instead of an empty list to have the optimization apply to
all `Op`s`.
inplace
A boolean indicating whether or not the optimization works in-place. A boolean indicating whether or not the optimization works in-place.
If ``True``, a `DestroyHandler` `Feature` is added automatically added If ``True``, a `DestroyHandler` `Feature` is added automatically added
to the `FunctionGraph`\s applied to this optimization. to the `FunctionGraph`\s applied to this optimization.
requirements : requirements
`Feature` types required by this optimization. `Feature` types required by this optimization.
""" """
...@@ -1236,14 +1247,14 @@ def local_optimizer( ...@@ -1236,14 +1247,14 @@ def local_optimizer(
if tracks is not None: if tracks is not None:
if len(tracks) == 0: if len(tracks) == 0:
raise ValueError( raise ValueError(
"Use None instead of an empty list to apply to all nodes.", "Use `None` instead of an empty list to make an optimization apply to all nodes."
f.__module__,
f.__name__,
) )
for t in tracks: for t in tracks:
if not (isinstance(t, Op) or issubclass(t, Op)): if not (
raise ValueError( isinstance(t, Op) or (isinstance(t, type) and issubclass(t, Op))
"Tracks are op classes or instances", f.__module__, f.__name__ ):
raise TypeError(
"`tracks` must consist of `Op` classes or instances."
) )
req = requirements req = requirements
if inplace: if inplace:
......
...@@ -704,3 +704,54 @@ def test_local_optimizer_str(): ...@@ -704,3 +704,54 @@ def test_local_optimizer_str():
assert res.startswith("FromFunctionLocalOptimizer(") assert res.startswith("FromFunctionLocalOptimizer(")
assert "Op1" in res assert "Op1" in res
assert "local_opt_1" in res assert "local_opt_1" in res
def test_local_optimizer():
with pytest.raises(ValueError):
@local_optimizer([])
def local_bad_1(fgraph, node):
return node.outputs
with pytest.raises(TypeError):
@local_optimizer([None])
def local_bad_2(fgraph, node):
return node.outputs
x = MyVariable("x")
y = MyVariable("y")
o1 = op1(x, y)
class MyNewOp(MyOp):
pass
o2 = MyNewOp("MyNewOp")(x, y)
class MyNewOp2(MyOp):
pass
o3 = MyNewOp2("MyNewOp2")(x, y)
fgraph = FunctionGraph([x, y], [o1, o2, o3], clone=False)
hits = [0]
@local_optimizer([op1, MyNewOp])
def local_opt_1(fgraph, node, hits=hits):
hits[0] += 1
return node.outputs
# This is allowed by the `op1` in `tracks`
local_opt_1.transform(fgraph, fgraph.outputs[0].owner)
assert hits[0] == 1
# This is allowed by the `MyOp` in `tracks`
local_opt_1.transform(fgraph, fgraph.outputs[1].owner)
assert hits[0] == 2
# This is not allowed by `tracks`
local_opt_1.transform(fgraph, fgraph.outputs[2].owner)
assert hits[0] == 2
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论