提交 289eb329 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

support broadcastable tensors in repeats

上级 c8dc3dbe
......@@ -732,9 +732,12 @@ def repeat(x, repeats, axis=None):
if repeats.ndim > 1:
raise ValueError('The dimension of repeats should not exceed 1.')
if repeats.ndim == 1:
return RepeatOp(axis=axis)(x, repeats)
if repeats.ndim == 1 and not repeats.broadcastable[0]:
return RepeatOp(axis=axis)(x, repeats)
else:
if repeats.ndim == 1:
repeats = repeats[0]
if axis is None:
axis = 0
x = x.flatten()
......
......@@ -445,6 +445,26 @@ class TestRepeatOp(utt.InferShapeTester):
assert np.allclose(np.repeat(a, r, axis=axis),
f(a, r))
#check when r is a list of single integer, e.g. [3].
r = np.random.random_integers(10, size=()).astype(dtype) + 2
f = theano.function([x],
repeat(x, [r], axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis),
f(a))
assert not np.any([isinstance(n.op, RepeatOp)
for n in f.maker.fgraph.toposort()])
# check when r is theano tensortype that broadcastable is (True,)
r_var = theano.tensor.TensorType(broadcastable=(True,),
dtype=dtype)()
r = np.random.random_integers(5, size=(1,)).astype(dtype)
f = theano.function([x, r_var],
repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r[0], axis=axis),
f(a, r))
assert not np.any([isinstance(n.op, RepeatOp)
for n in f.maker.fgraph.toposort()])
@attr('slow')
def test_infer_shape(self):
for ndim in range(4):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论