提交 9ca8c709 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add assert about the shape in the graph.

fix gh-1738
上级 f3f97d77
......@@ -370,10 +370,16 @@ def local_0_dot_x(node):
if replace:
constant_zero = T.constant(0, dtype=node.outputs[0].type.dtype)
if x.ndim == 2 and y.ndim == 2:
constant_zero = assert_(constant_zero,
T.eq(x.shape[1], y.shape[0]))
return [T.alloc(constant_zero, x.shape[0], y.shape[1])]
elif x.ndim == 1 and y.ndim == 2:
constant_zero = assert_(constant_zero,
T.eq(x.shape[0], y.shape[0]))
return [T.alloc(constant_zero, y.shape[1])]
elif x.ndim == 2 and y.ndim == 1:
constant_zero = assert_(constant_zero,
T.eq(x.shape[1], y.shape[1]))
return [T.alloc(constant_zero, x.shape[0])]
elif x.ndim == 1 and y.ndim == 1:
return [constant_zero]
......
......@@ -2395,6 +2395,16 @@ class Test_alloc_zero(unittest.TestCase):
assert numpy.all([ not isinstance(x.op, tensor.Dot) for x in
f.maker.fgraph.toposort() ])
def test_dot_allocs_0_err(self):
#test that we don't remove errors
m1 = tensor.matrix('m1')
vm = numpy.asarray([[1, 2, 3], [4, 5, 6]],
dtype=theano.config.floatX)
z = numpy.zeros((3, 3), dtype=theano.config.floatX)
o = tensor.dot(z, m1)
f = theano.function([m1], o, mode=self.mode)
self.assertRaises((ValueError, AssertionError), f, vm)
def test_local_subtensor_of_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论