提交 addbc060 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2654 from Saizheng/dev1

Replace RepeatOp with Theano graph, using tensor.tile
...@@ -647,9 +647,47 @@ def repeat(x, repeats, axis=None): ...@@ -647,9 +647,47 @@ def repeat(x, repeats, axis=None):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
return RepeatOp(axis=axis)(x, repeats) repeats = tensor.as_tensor_variable(repeats)
if repeats.ndim > 1:
raise ValueError('The dimension of repeats should not exceed 1.')
if repeats.ndim == 1:
return RepeatOp(axis=axis)(x, repeats)
else:
if axis == None:
axis = 0
x = x.flatten()
else:
if axis >= x.ndim:
raise ValueError('Axis should not exceed x.ndim-1.')
if axis < 0:
axis = x.ndim+axis
shape = [x.shape[i] for i in xrange(x.ndim)]
# shape_ is the shape of the intermediate tensor which has
# an additional dimension comparing to x. We use alloc to
# allocate space for this intermediate tensor to replicate x
# along that additional dimension.
shape_ = shape[:]
shape_.insert(axis+1, repeats)
# shape is now the shape of output, where shape[axis] becomes
# shape[axis]*repeats.
shape[axis] = shape[axis]*repeats
# dims_ is the dimension of that intermediate tensor.
dims_ = list(numpy.arange(x.ndim))
dims_.insert(axis+1, 'x')
# After the original tensor is duplicated along the additional
# dimension, we reshape it to the expected output shape, and
# return the output z.
z = tensor.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
return z
class Bartlett(gof.Op): class Bartlett(gof.Op):
# See function bartlett for docstring # See function bartlett for docstring
def __eq__(self, other): def __eq__(self, other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论