提交 9d371eee authored 作者: Razvan Pascanu's avatar Razvan Pascanu

corner case that I forgot to include in scan optimization

上级 f9832742
...@@ -1072,7 +1072,7 @@ class ScanSpaceOptimizer(Optimizer): ...@@ -1072,7 +1072,7 @@ class ScanSpaceOptimizer(Optimizer):
for i,out in enumerate(node.outputs): for i,out in enumerate(node.outputs):
if op.store_steps[i] == 0 : if op.store_steps[i] == 0 :
# if we do not have a range for this output # if we do not have a range for this output
req_steps = 0 req_steps = numpy.max(numpy.abs(op.outs_taps.get(i,1)))
# look at all its clients # look at all its clients
for cl,_dx in out.clients: for cl,_dx in out.clients:
if type(cl) == str: if type(cl) == str:
...@@ -1091,14 +1091,14 @@ class ScanSpaceOptimizer(Optimizer): ...@@ -1091,14 +1091,14 @@ class ScanSpaceOptimizer(Optimizer):
# if it is a tensor, and the first # if it is a tensor, and the first
# dimension is just -1 # dimension is just -1
if cl.op.idx_list[0] == -1 : if cl.op.idx_list[0] == -1 :
req_steps = 1 req_steps = numpy.max([1, req_steps])
else: else:
# or a constant that evaluates to # or a constant that evaluates to
# -1 # -1
try: try:
idx = opt.get_constant_value(cl.op.idx_list[0]) idx = opt.get_constant_value(cl.op.idx_list[0])
if idx== -1: if idx== -1:
req_steps = 1 req_steps = numpy.max([1, req_steps])
else: else:
req_steps = 0 req_steps = 0
break break
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论