提交 219a9516 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix buggy node recreation approach in scan optimizations

上级 f31b78d9
...@@ -74,6 +74,7 @@ from aesara.graph.basic import ( ...@@ -74,6 +74,7 @@ from aesara.graph.basic import (
) )
from aesara.graph.destroyhandler import DestroyHandler from aesara.graph.destroyhandler import DestroyHandler
from aesara.graph.fg import InconsistencyError from aesara.graph.fg import InconsistencyError
from aesara.graph.op import compute_test_value
from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.graph.toolbox import ReplaceValidate from aesara.graph.toolbox import ReplaceValidate
...@@ -349,8 +350,10 @@ class PushOutNonSeqScan(GlobalOptimizer): ...@@ -349,8 +350,10 @@ class PushOutNonSeqScan(GlobalOptimizer):
x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins) x.type.filter_variable(y) for x, y in zip(nd.inputs, outside_ins)
] ]
# Do not call make_node for test_value nw_outer_node = nd.op.make_node(*outside_ins)
nw_outer_node = nd.op(*outside_ins, **dict(return_list=True))[0].owner
if config.compute_test_value != "off":
compute_test_value(nw_outer_node)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
...@@ -571,7 +574,10 @@ class PushOutSeqScan(GlobalOptimizer): ...@@ -571,7 +574,10 @@ class PushOutSeqScan(GlobalOptimizer):
to_remove_set.add(nd) to_remove_set.add(nd)
# Do not call make_node for test_value # Do not call make_node for test_value
nw_outer_node = nd.op(*outside_ins, **dict(return_list=True))[0].owner nw_outer_node = nd.op.make_node(*outside_ins)
if config.compute_test_value != "off":
compute_test_value(nw_outer_node)
# Step 2. Create variables for replacements # Step 2. Create variables for replacements
for idx, y in enumerate(nd.outputs): for idx, y in enumerate(nd.outputs):
...@@ -1033,7 +1039,17 @@ class ScanInplaceOptimizer(GlobalOptimizer): ...@@ -1033,7 +1039,17 @@ class ScanInplaceOptimizer(GlobalOptimizer):
and inp.owner and inp.owner
and isinstance(inp.owner.op, alloc_ops) and isinstance(inp.owner.op, alloc_ops)
): ):
ls[i] = inp.owner.op(*inp.owner.inputs) new_lsi = inp.owner.op.make_node(*inp.owner.inputs)
if config.compute_test_value != "off":
compute_test_value(new_lsi)
new_lsi_out = new_lsi.outputs
if len(new_lsi_out) == 1:
new_lsi_out = new_lsi_out[0]
ls[i] = new_lsi_out
n_outs = len(ls) n_outs = len(ls)
for idx in range(n_outs): for idx in range(n_outs):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论