提交 27d2bfe3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make NoOutputFromInplace take arbitrary collections

上级 beb510e6
...@@ -742,15 +742,20 @@ class PreserveVariableAttributes(Feature): ...@@ -742,15 +742,20 @@ class PreserveVariableAttributes(Feature):
class NoOutputFromInplace(Feature): class NoOutputFromInplace(Feature):
"""Prevent `FunctionGraph` outputs within a range from being altered in-place.""" """Prevent `FunctionGraph` outputs within a range from being altered in-place."""
def __init__(self, first_output_idx=0, last_output_idx=None): def __init__(self, protected_out_ids):
self.first_idx = first_output_idx self.protected_out_ids = tuple(protected_out_ids)
self.last_idx = last_output_idx
def on_attach(self, fgraph):
if hasattr(fgraph, "_no_output_from_inplace"):
raise AlreadyThere(f"InnerGraphWatcher is already attached to {fgraph}.")
fgraph._no_output_from_inplace = self
def validate(self, fgraph): def validate(self, fgraph):
if not hasattr(fgraph, "destroyers"): if not hasattr(fgraph, "destroyers"):
return True return True
for out in fgraph.outputs[self.first_idx : self.last_idx]: for out in tuple(fgraph.outputs[i] for i in self.protected_out_ids):
node = out.owner node = out.owner
...@@ -768,3 +773,5 @@ class NoOutputFromInplace(Feature): ...@@ -768,3 +773,5 @@ class NoOutputFromInplace(Feature):
f"operations. This has prevented the output {out} from " f"operations. This has prevented the output {out} from "
"being computed by modifying another variable in-place." "being computed by modifying another variable in-place."
) )
return True
...@@ -877,7 +877,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -877,7 +877,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs) mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
features.append(NoOutputFromInplace(mitsot_start, nitsot_end)) features.append(NoOutputFromInplace(range(mitsot_start, nitsot_end)))
self.fgraph = FunctionGraph( self.fgraph = FunctionGraph(
inputs, inputs,
......
...@@ -15,25 +15,30 @@ def test_Mode_basic(): ...@@ -15,25 +15,30 @@ def test_Mode_basic():
assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery") assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery")
def test_no_output_from_implace(): def test_NoOutputFromInplace():
x = matrix() x = matrix()
y = matrix() y = matrix()
a = dot(x, y) a = dot(x, y)
b = tanh(a) b = tanh(a)
c = tanh(dot(2 * x, y))
# Ensure that the elemwise op that produces the output is inplace when # Ensure that the elemwise op that produces the output is inplace when
# using a mode that does not include the optimization # using a mode that does not include the optimization
fct_no_opt = function([x, y], b, mode="FAST_RUN") fct_no_opt = function([x, y], [b, c], mode="FAST_RUN")
op = fct_no_opt.maker.fgraph.outputs[0].owner.op op = fct_no_opt.maker.fgraph.outputs[0].owner.op
assert op.destroy_map and 0 in op.destroy_map assert op.destroy_map and 0 in op.destroy_map
op = fct_no_opt.maker.fgraph.outputs[1].owner.op
assert op.destroy_map and 0 in op.destroy_map
# Ensure that the elemwise op that produces the output is not inplace when # Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization # using a mode that includes the optimization
opt = AddFeatureOptimizer(NoOutputFromInplace()) opt = AddFeatureOptimizer(NoOutputFromInplace([1]))
mode_opt = Mode(linker="py", optimizer="fast_run").register((opt, 49.9)) mode_opt = Mode(linker="py", optimizer="fast_run").register((opt, 49.9))
fct_opt = function([x, y], b, mode=mode_opt) fct_opt = function([x, y], [b, c], mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op op = fct_opt.maker.fgraph.outputs[0].owner.op
assert op.destroy_map and 0 in op.destroy_map
op = fct_opt.maker.fgraph.outputs[1].owner.op
assert not op.destroy_map or 0 not in op.destroy_map assert not op.destroy_map or 0 not in op.destroy_map
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论