提交 9c9f2166 authored 作者: lamblin's avatar lamblin

Merge pull request #1125 from abalkin/repeat-bug

Repeat Op should not copy broadcastable flag for the axis along which it...
......@@ -589,7 +589,7 @@ def get_constant_value(v):
v.owner.op.idx_list[0]]:
return numpy.asarray(1)
raise TypeError(v)
raise NotConstantError(v)
class TensorType(Type):
......
......@@ -259,10 +259,20 @@ class RepeatOp(theano.Op):
% numpy_unsupported_dtypes), repeats.dtype)
if self.axis is None:
out_type = theano.tensor.TensorType(dtype=x.dtype,
broadcastable=[False])
broadcastable=[False]
else:
out_type = x.type
try:
const_reps = basic.get_constant_value(repeats)
except basic.NotConstantError:
const_reps = None
if const_reps == 1:
broadcastable = x.broadcastable
else:
broadcastable = list(x.broadcastable)
broadcastable[self.axis] = False
out_type = theano.tensor.TensorType(x.dtype, broadcastable)
return theano.Apply(self, [x, repeats], [out_type()])
def perform(self, node, inputs, output_storage):
......
......@@ -282,6 +282,14 @@ class TestRepeatOp(utt.InferShapeTester):
for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
def test_broadcastable(self):
x = T.TensorType(config.floatX, [False, True, False])()
r = RepeatOp(axis=1)(x, 2)
self.assertEqual(r.broadcastable, (False, False, False))
r = RepeatOp(axis=1)(x, 1)
self.assertEqual(r.broadcastable, (False, True, False))
r = RepeatOp(axis=0)(x, 2)
self.assertEqual(r.broadcastable, (False, True, False))
class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论