提交 27d2bfe3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make NoOutputFromInplace take arbitrary collections

上级 beb510e6
......@@ -742,15 +742,20 @@ class PreserveVariableAttributes(Feature):
class NoOutputFromInplace(Feature):
"""Prevent `FunctionGraph` outputs within a range from being altered in-place."""
def __init__(self, first_output_idx=0, last_output_idx=None):
self.first_idx = first_output_idx
self.last_idx = last_output_idx
def __init__(self, protected_out_ids):
self.protected_out_ids = tuple(protected_out_ids)
def on_attach(self, fgraph):
if hasattr(fgraph, "_no_output_from_inplace"):
raise AlreadyThere(f"InnerGraphWatcher is already attached to {fgraph}.")
fgraph._no_output_from_inplace = self
def validate(self, fgraph):
if not hasattr(fgraph, "destroyers"):
return True
for out in fgraph.outputs[self.first_idx : self.last_idx]:
for out in tuple(fgraph.outputs[i] for i in self.protected_out_ids):
node = out.owner
......@@ -768,3 +773,5 @@ class NoOutputFromInplace(Feature):
f"operations. This has prevented the output {out} from "
"being computed by modifying another variable in-place."
)
return True
......@@ -877,7 +877,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitsot_start = info.n_mit_mot_outs - len(self.preallocated_mitmot_outs)
nitsot_end = mitsot_start + info.n_mit_sot + info.n_sit_sot + info.n_nit_sot
features.append(NoOutputFromInplace(mitsot_start, nitsot_end))
features.append(NoOutputFromInplace(range(mitsot_start, nitsot_end)))
self.fgraph = FunctionGraph(
inputs,
......
......@@ -15,25 +15,30 @@ def test_Mode_basic():
assert str(mode).startswith("Mode(linker=py, optimizer=OptimizationQuery")
def test_no_output_from_implace():
def test_NoOutputFromInplace():
x = matrix()
y = matrix()
a = dot(x, y)
b = tanh(a)
c = tanh(dot(2 * x, y))
# Ensure that the elemwise op that produces the output is inplace when
# using a mode that does not include the optimization
fct_no_opt = function([x, y], b, mode="FAST_RUN")
fct_no_opt = function([x, y], [b, c], mode="FAST_RUN")
op = fct_no_opt.maker.fgraph.outputs[0].owner.op
assert op.destroy_map and 0 in op.destroy_map
op = fct_no_opt.maker.fgraph.outputs[1].owner.op
assert op.destroy_map and 0 in op.destroy_map
# Ensure that the elemwise op that produces the output is not inplace when
# using a mode that includes the optimization
opt = AddFeatureOptimizer(NoOutputFromInplace())
opt = AddFeatureOptimizer(NoOutputFromInplace([1]))
mode_opt = Mode(linker="py", optimizer="fast_run").register((opt, 49.9))
fct_opt = function([x, y], b, mode=mode_opt)
fct_opt = function([x, y], [b, c], mode=mode_opt)
op = fct_opt.maker.fgraph.outputs[0].owner.op
assert op.destroy_map and 0 in op.destroy_map
op = fct_opt.maker.fgraph.outputs[1].owner.op
assert not op.destroy_map or 0 not in op.destroy_map
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论