提交 c8bbcf99 authored 作者: Frederic's avatar Frederic

Fix crash with test_value and empty elemwise

上级 a664e329
...@@ -167,6 +167,23 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -167,6 +167,23 @@ class TestComputeTestValue(unittest.TestCase):
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def test_empty_elemwise(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
x = theano.shared(numpy.random.rand(0, 6).astype(config.floatX),
'x')
# should work
z = (x + 2) * 3
assert hasattr(z.tag, 'test_value')
f = theano.function([], z)
assert _allclose(f(), z.tag.test_value)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_constant(self): def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
......
...@@ -5433,7 +5433,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024, ...@@ -5433,7 +5433,11 @@ def local_elemwise_fusion_op(OP, max_input_fct=lambda node: 1024,
else: else:
tmp = scalar.get_scalar_type(ii.dtype).make_variable() tmp = scalar.get_scalar_type(ii.dtype).make_variable()
try: try:
tmp.tag.test_value = gof.op.get_test_value(ii).flatten()[0] tv = gof.op.get_test_value(ii)
if tv.size > 0:
tmp.tag.test_value = tv.flatten()[0]
else:
tmp.tag.test_value = tv
except AttributeError: except AttributeError:
pass pass
tmp_s_input.append(tmp) tmp_s_input.append(tmp)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论