提交 6e9b36e8 · 作者: Frédéric Bastien · 提交者: GitHub

Merge pull request #5587 from nouiz/scan_fix

Scan fix
...@@ -1516,15 +1516,10 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1516,15 +1516,10 @@ class ScanSaveMem(gof.Optimizer):
node_ins] node_ins]
node_ins = pre_constant_merge(node_ins) node_ins = pre_constant_merge(node_ins)
# 3.6 Compose the new scan # 3.6 Compose the new scan
# I need to make sure I'm not reapplying the same optimization
# twice since bad things usually happen if I do that
# TODO: why not check if save mem was done on any of merged nodes?
# That way, if none of them had save mem applied, it would
# be applied later.
info['_scan_savemem_visited'] = True
# TODO: currently we don't support scan with 0 step. So # TODO: currently we don't support scan with 0 step. So
# don't create one. # don't create one.
# For tests, mark that savemem has optimized this node
info['_scan_savemem_visited'] = True
if theano.tensor.extract_constant(node_ins[0]) == 0: if theano.tensor.extract_constant(node_ins[0]) == 0:
return return
...@@ -1627,8 +1622,7 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1627,8 +1622,7 @@ class ScanSaveMem(gof.Optimizer):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, nodelist = [x for x in fgraph.toposort() if isinstance(x.op,
scan_op.Scan)] scan_op.Scan)]
for node in nodelist: for node in nodelist:
if not hasattr(node.op, '_scan_savemem_visited'): self.process_node(fgraph, node)
self.process_node(fgraph, node)
class ScanMerge(gof.Optimizer): class ScanMerge(gof.Optimizer):
...@@ -1819,8 +1813,6 @@ class ScanMerge(gof.Optimizer): ...@@ -1819,8 +1813,6 @@ class ScanMerge(gof.Optimizer):
""" """
rep = set_nodes[0] rep = set_nodes[0]
if (rep.op.as_while != node.op.as_while or if (rep.op.as_while != node.op.as_while or
len(rep.inputs) != len(node.inputs) or
len(rep.outputs) != len(node.outputs) or
node.op.truncate_gradient != rep.op.truncate_gradient or node.op.truncate_gradient != rep.op.truncate_gradient or
node.op.mode != rep.op.mode): node.op.mode != rep.op.mode):
return False return False
...@@ -2266,6 +2258,8 @@ optdb.register('scan_eqopt1', scan_eqopt1, .05, 'fast_run', 'scan') ...@@ -2266,6 +2258,8 @@ optdb.register('scan_eqopt1', scan_eqopt1, .05, 'fast_run', 'scan')
# We run before blas opt at 1.7 and specialize 2.0 # We run before blas opt at 1.7 and specialize 2.0
# but after stabilize at 1.5. Should we put it before stabilize? # but after stabilize at 1.5. Should we put it before stabilize?
optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan') optdb.register('scan_eqopt2', scan_eqopt2, 1.6, 'fast_run', 'scan')
# ScanSaveMem should execute only once per node.
optdb.register('scanOp_save_mem', ScanSaveMem(), 1.61, 'fast_run', 'scan')
optdb.register('scanOp_make_inplace', optdb.register('scanOp_make_inplace',
ScanInplaceOptimizer(typeInfer=None, ScanInplaceOptimizer(typeInfer=None,
gpu_flag=False), gpu_flag=False),
...@@ -2359,15 +2353,6 @@ scan_eqopt2.register('scanOp_merge_inouts', ...@@ -2359,15 +2353,6 @@ scan_eqopt2.register('scanOp_merge_inouts',
'fast_run', 'fast_run',
'scan') 'scan')
# Just before specialize, so that other optimizations
# like constant folding have already been applied.
# This doesn't introduce inplace operations.
scan_eqopt2.register('scanOp_save_mem',
ScanSaveMem(),
7,
'fast_run',
'scan')
# After everything else # After everything else
scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3', scan_eqopt2.register('scanOp_remove_constants_and_unused_inputs3',
opt.in2out(remove_constants_and_unused_inputs_scan, opt.in2out(remove_constants_and_unused_inputs_scan,
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论