提交 7d30ab3c authored 作者: abergeron's avatar abergeron

Merge pull request #1740 from nouiz/shape_err

Add assert about the shape in the graph.
......@@ -370,12 +370,20 @@ def local_0_dot_x(node):
if replace:
constant_zero = T.constant(0, dtype=node.outputs[0].type.dtype)
if x.ndim == 2 and y.ndim == 2:
constant_zero = assert_(constant_zero,
T.eq(x.shape[1], y.shape[0]))
return [T.alloc(constant_zero, x.shape[0], y.shape[1])]
elif x.ndim == 1 and y.ndim == 2:
constant_zero = assert_(constant_zero,
T.eq(x.shape[0], y.shape[0]))
return [T.alloc(constant_zero, y.shape[1])]
elif x.ndim == 2 and y.ndim == 1:
constant_zero = assert_(constant_zero,
T.eq(x.shape[1], y.shape[0]))
return [T.alloc(constant_zero, x.shape[0])]
elif x.ndim == 1 and y.ndim == 1:
constant_zero = assert_(constant_zero,
T.eq(x.shape[0], y.shape[0]))
return [constant_zero]
else:
_logger.warning("Optimization Warning: "
......
......@@ -2377,11 +2377,14 @@ class Test_alloc_zero(unittest.TestCase):
v2 = tensor.vector('v2')
m1 = tensor.matrix('m1')
m2 = tensor.matrix('m2')
vv = numpy.asarray([0, 1, 2], dtype=theano.config.floatX)
vm = numpy.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=theano.config.floatX)
for _e1 in [(v1, vv), (m1, vm)]:
for _e2 in [(v2, vv), (m2, vm)]:
vv2 = numpy.asarray([0, 1], dtype=theano.config.floatX)
vm2 = numpy.asarray([[1, 2], [4, 5]],
dtype=theano.config.floatX)
vv3 = numpy.asarray([0, 1, 2], dtype=theano.config.floatX)
vm3 = numpy.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=theano.config.floatX)
for _e1 in [(v1, vv2, vv3), (m1, vm2, vm3)]:
for _e2 in [(v2, vv2, vv3), (m2, vm2, vm3)]:
for p in [0, 1]:
if p == 0:
e1 = tensor.zeros_like(_e1[0])
......@@ -2392,9 +2395,16 @@ class Test_alloc_zero(unittest.TestCase):
o = tensor.dot(e1, e2)
f = theano.function([_e1[0], _e2[0]], o, mode=self.mode)
f(_e1[1], _e2[1])
assert numpy.all([ not isinstance(x.op, tensor.Dot) for x in
f(_e1[2], _e2[2])
assert numpy.all([not isinstance(x.op, tensor.Dot) for x in
f.maker.fgraph.toposort() ])
#test that we don't remove shape errors
self.assertRaises((ValueError, AssertionError), f,
_e1[1], _e2[2])
self.assertRaises((ValueError, AssertionError), f,
_e1[2], _e2[1])
def test_local_subtensor_of_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论