提交 6580711f authored 作者: abalkin's avatar abalkin

Added logic to preserve broadcastable flag if repeats=1. Added unit test.

上级 c9dec33e
......@@ -261,8 +261,15 @@ class RepeatOp(theano.Op):
if self.axis is None:
broadcastable=[False]
else:
broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False
try:
const_reps = basic.get_constant_value(repeats)
except TypeError:
const_reps = None
if const_reps == 1:
broadcastable = x.broadcastable
else:
broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable)
......
......@@ -282,6 +282,14 @@ class TestRepeatOp(utt.InferShapeTester):
for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
def test_broadcastable(self):
x = T.TensorType(config.floatX, [False, True, False])()
r = RepeatOp(axis=1)(x, 2)
self.assertEqual(r.broadcastable, (False, False, False))
r = RepeatOp(axis=1)(x, 1)
self.assertEqual(r.broadcastable, (False, True, False))
r = RepeatOp(axis=0)(x, 2)
self.assertEqual(r.broadcastable, (False, True, False))
class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论