提交 9c9f2166 authored 作者: lamblin's avatar lamblin

Merge pull request #1125 from abalkin/repeat-bug

Repeat Op should not copy broadcastable flag for the axis along which it...
...@@ -589,7 +589,7 @@ def get_constant_value(v): ...@@ -589,7 +589,7 @@ def get_constant_value(v):
v.owner.op.idx_list[0]]: v.owner.op.idx_list[0]]:
return numpy.asarray(1) return numpy.asarray(1)
raise TypeError(v) raise NotConstantError(v)
class TensorType(Type): class TensorType(Type):
......
...@@ -259,10 +259,20 @@ class RepeatOp(theano.Op): ...@@ -259,10 +259,20 @@ class RepeatOp(theano.Op):
% numpy_unsupported_dtypes), repeats.dtype) % numpy_unsupported_dtypes), repeats.dtype)
if self.axis is None: if self.axis is None:
out_type = theano.tensor.TensorType(dtype=x.dtype, broadcastable=[False]
broadcastable=[False])
else: else:
out_type = x.type try:
const_reps = basic.get_constant_value(repeats)
except basic.NotConstantError:
const_reps = None
if const_reps == 1:
broadcastable = x.broadcastable
else:
broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable)
return theano.Apply(self, [x, repeats], [out_type()]) return theano.Apply(self, [x, repeats], [out_type()])
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
......
...@@ -282,6 +282,14 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -282,6 +282,14 @@ class TestRepeatOp(utt.InferShapeTester):
for axis in self._possible_axis(ndim): for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a]) utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
def test_broadcastable(self):
x = T.TensorType(config.floatX, [False, True, False])()
r = RepeatOp(axis=1)(x, 2)
self.assertEqual(r.broadcastable, (False, False, False))
r = RepeatOp(axis=1)(x, 1)
self.assertEqual(r.broadcastable, (False, True, False))
r = RepeatOp(axis=0)(x, 2)
self.assertEqual(r.broadcastable, (False, True, False))
class TestBartlett(utt.InferShapeTester): class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论