提交 582e390a authored 作者: Frederic's avatar Frederic

[BUG] Fix bug introduced by gh-2195 merged yesterday.

This relax a condition in DimShuffle to help reuse it. We where removing the DimShuffle, it was working when it was only adding outer dimensions, as the Elemwise would readd it. But if it did other stuff, we where loosing what it did.
上级 9e439d5f
......@@ -182,10 +182,20 @@ class DimShuffle(Op):
input = as_tensor_variable(_input)
ib = tuple(input.type.broadcastable)
if not ib == self.input_broadcastable:
raise TypeError((
"The number of dimensions and/or broadcastable pattern of the "
"input is incorrect for this op. Expected %s, got %s."
% (self.input_broadcastable, ib)))
if len(ib) != len(self.input_broadcastable):
raise TypeError((
"The number of dimensions of the "
"input is incorrect for this op. Expected %s, got %s."
% (self.input_broadcastable, ib)))
for expected, b in zip(self.input_broadcastable, ib):
if expected is True and b is False:
raise TypeError((
"The broadcastable pattern of the "
"input is incorrect for this op. Expected %s, got %s."
% (self.input_broadcastable, ib)))
#else, expected == b or expected is False and b is True
# Both case are good.
ob = []
for value in self.new_order:
if value == 'x':
......
......@@ -1639,7 +1639,11 @@ def local_alloc_elemwise(node):
*[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]])
new_i.append(i.owner.inputs[0].owner.inputs[0])
alloc_input = i.owner.inputs[0].owner.inputs[0]
assert alloc_input.ndim == i.owner.inputs[0].ndim
# We need to keep the dimshuffle. It could swap axes or
# add dimensions anywhere.
new_i.append(i.owner.op(alloc_input))
else:
new_i.append(i)
new_i[assert_op_idx] = assert_op
......
......@@ -2687,6 +2687,21 @@ class Test_local_alloc_elemwise(unittest.TestCase):
self._verify_alloc_count(func, 1)
self._verify_assert_count(func, 1)
def test_error(self):
t3fft = theano.tensor.tensor(dtype=self.dtype,
broadcastable=(False, False, True))
row = theano.tensor.row(dtype=self.dtype)
o = T.alloc(row, 5, 5).dimshuffle(0, 1, 'x') + t3fft
func = function(
[t3fft, row],
o,
mode='FAST_RUN'
)
self._verify_alloc_count(func, 0)
self._verify_assert_count(func, 1)
d = numpy.random.rand(5, 5, 1).astype(self.dtype)
r = numpy.random.rand(1, 5).astype(self.dtype)
func(d, r)
def test_local_subtensor_of_alloc():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论