提交 44ffb704 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #1339 from lamblin/fix_merge_abs

Fix bug in local_abs_merge optimization
...@@ -3659,7 +3659,10 @@ def local_abs_merge(node): ...@@ -3659,7 +3659,10 @@ def local_abs_merge(node):
if i.owner and i.owner.op == T.abs_: if i.owner and i.owner.op == T.abs_:
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
else: else:
try:
const = get_scalar_constant_value(i) const = get_scalar_constant_value(i)
except NotScalarConstantError:
return False
if not (const >= 0).all(): if not (const >= 0).all():
return False return False
inputs.append(i) inputs.append(i)
......
...@@ -782,6 +782,19 @@ def test_local_merge_abs(): ...@@ -782,6 +782,19 @@ def test_local_merge_abs():
assert len(f.maker.fgraph.toposort()) == 2 assert len(f.maker.fgraph.toposort()) == 2
def test_merge_abs_bugfix():
# Test crash in optimization reported by Jeremiah Lowin at
# https://groups.google.com/d/topic/theano-users/TaXfqXP2Mj0/discussion
input = T.matrix()
# normalize on cols
step1 = input / input.sum(0)
# normalize on rows
step2 = step1 / step1.sum(1)
# get l1 norm
l1_norm = T.abs_(step2).sum()
theano.function([input], T.grad(l1_norm, input))
def test_mixeddiv(): def test_mixeddiv():
"""Test that int division is preserved""" """Test that int division is preserved"""
i = iscalar() i = iscalar()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论