提交 420b9730 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

repeats function

上级 2db9cc07
......@@ -647,18 +647,34 @@ def repeat(x, repeats, axis=None):
.. versionadded:: 0.6
"""
ndim = x.ndim
shape = [x.shape[i] for i in xrange(ndim)]
if axis is None:
rep_tile = [1 for i in xrange(ndim)] + [repeats]
z = tensor.flatten(tensor.tile(x.reshape(shape+[1]), rep_tile))
if isinstance(repeats, (int, long, numpy.integer, list)):
repeats = numpy.asarray(repeats)
if repeats.ndim > 1:
raise ValueError('The dimension of repeats should not exceed 1.')
if repeats.ndim == 1:
return RepeatOp(axis=axis)(x, repeats)
else:
rep_tile = [1 for i in xrange(ndim)]
rep_tile[axis] = repeats
z = tensor.tile(x, rep_tile)
return z
if axis == None:
axis = 0
x = x.flatten()
else:
if axis >= x.ndim:
raise ValueError('Axis should not exceed x.ndim-1.')
shape = [x.shape[i] for i in xrange(x.ndim)]
shape_ = shape[:]
shape_.insert(axis+1, repeats)
shape[axis] = shape[axis]*repeats
dims = list(numpy.arange(x.ndim))
dims.insert(axis+1, 'x')
z = tensor.alloc(x.dimshuffle(*dims), *shape_).reshape(shape)
return z
class Bartlett(gof.Op):
# See function bartlett for docstring
def __eq__(self, other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论