提交 b16c4fda authored 作者: Frederic's avatar Frederic

use op.__call__ instead of make_node in opt for test_value.

上级 79c209b0
...@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node): ...@@ -144,7 +144,7 @@ def remove_constants_and_unused_inputs_scan(node):
nw_info['n_seqs'] = nw_n_seqs nw_info['n_seqs'] = nw_n_seqs
# DEBUG CHECK # DEBUG CHECK
nwScan = scan_op.Scan(nw_inner, op_outs, nw_info) nwScan = scan_op.Scan(nw_inner, op_outs, nw_info)
nw_outs = nwScan.make_node(*nw_outer).outputs nw_outs = nwScan(*nw_outer, return_list=True)
return nw_outs return nw_outs
else: else:
return False return False
...@@ -574,7 +574,8 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -574,7 +574,8 @@ class ScanInplaceOptimizer(Optimizer):
info, info,
typeConstructor=self.typeConstructor) typeConstructor=self.typeConstructor)
new_outs = new_op.make_node(*inputs).outputs # Do not call make_node for test_value
new_outs = new_op(*inputs, return_list=True)
try: try:
fgraph.replace_all_validate_remove( fgraph.replace_all_validate_remove(
zip(node.outputs, new_outs), zip(node.outputs, new_outs),
...@@ -957,9 +958,10 @@ class ScanSaveMem(gof.Optimizer): ...@@ -957,9 +958,10 @@ class ScanSaveMem(gof.Optimizer):
# I need to make sure I'm not reapplying the same optimization # I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that # twice since bad things usually happen if I do that
info['_scan_savemem_visited'] = True info['_scan_savemem_visited'] = True
new_outs = scan_op.Scan(inps,
outs, # Do not call make_node for test_value
info).make_node(*node_ins).outputs new_outs = scan_op.Scan(inps, outs, info)(*node_ins,
return_list=True)
old_new = [] old_new = []
# 3.7 Get replace pairs for those outputs that do not change # 3.7 Get replace pairs for those outputs that do not change
...@@ -989,8 +991,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -989,8 +991,7 @@ class ScanSaveMem(gof.Optimizer):
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos], new_o = subtens(new_outs[nw_pos], *sl_ins)
*sl_ins).outputs[0]
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]] new_o = new_o[::cnf_slice[1]]
replaced_outs.append(idx) replaced_outs.append(idx)
...@@ -1028,8 +1029,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1028,8 +1029,7 @@ class ScanSaveMem(gof.Optimizer):
nw_slice, nw_slice,
lambda entry: isinstance(entry, lambda entry: isinstance(entry,
tensor.Variable)) tensor.Variable))
new_o = subtens.make_node(new_outs[nw_pos], new_o = subtens(new_outs[nw_pos], *sl_ins)
*sl_ins).outputs[0]
if new_o.ndim > 0: if new_o.ndim > 0:
new_o = new_o[::cnf_slice[1]] new_o = new_o[::cnf_slice[1]]
old_new += [(old, new_o)] old_new += [(old, new_o)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论