提交 1e24b90d authored 作者: Ian Goodfellow's avatar Ian Goodfellow

test reduce with shape 0

上级 0975c51f
...@@ -202,6 +202,11 @@ def test_max(): ...@@ -202,6 +202,11 @@ def test_max():
return gpu_pattern return gpu_pattern
for shape, pattern in [((1,1),(1,)), for shape, pattern in [((1,1),(1,)),
((1,0),(1,)),
((0,1),(1,)),
((0,0),(1,)),
((0,0,0),(1,2)),
((0,0,0,0),(1,2,3)),
((2,1),(1,)), ((2,1),(1,)),
((1,2),(1,)), ((1,2),(1,)),
((100,3,1300),[1]), ((100,3,1300),[1]),
...@@ -261,9 +266,43 @@ def test_max(): ...@@ -261,9 +266,43 @@ def test_max():
# val = numpy.arange(numpy.prod(shape)).reshape(shape) # val = numpy.arange(numpy.prod(shape)).reshape(shape)
val = theano._asarray(val, dtype='float32') val = theano._asarray(val, dtype='float32')
f = theano.function([a], b, mode=mode_with_gpu) f = theano.function([a], b, mode=mode_with_gpu)
try:
f_out = f(val)
f_caused_value_error = False
except ValueError, e:
exc = e
f_caused_value_error = True
f2 = theano.function([a], b, mode=mode_without_gpu) f2 = theano.function([a], b, mode=mode_without_gpu)
try:
f2_out = f2(val)
f2_caused_value_error = False
except ValueError, e:
exc2 = e
f2_caused_value_error = True
assert tcn.GpuCAReduce in [x.op.__class__ for x in f.maker.fgraph.toposort()] assert tcn.GpuCAReduce in [x.op.__class__ for x in f.maker.fgraph.toposort()]
assert T.CAReduce in [x.op.__class__ for x in f2.maker.fgraph.toposort()] assert T.CAReduce in [x.op.__class__ for x in f2.maker.fgraph.toposort()]
# Check that 0 shape matrices are invalid in the same cases
if f_caused_value_error != f2_caused_value_error:
if f_caused_value_error:
print 'f caused this value error:'
print exc
else:
print 'f did not raise a value error, but should have'
if f2_caused_value_error:
print 'f2 caused this value error:'
print exc2
else:
print 'f should not have raised a value error'
print 'shape was: ',shape
print 'pattern was: ',pattern
assert False
if f_caused_value_error:
continue
if val.size == 0: if val.size == 0:
assert f2(val) == f(val), ('shape', shape, 'pattern', pattern) assert f2(val) == f(val), ('shape', shape, 'pattern', pattern)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论