提交 420b9730 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

repeats function

上级 2db9cc07
...@@ -647,18 +647,34 @@ def repeat(x, repeats, axis=None): ...@@ -647,18 +647,34 @@ def repeat(x, repeats, axis=None):
.. versionadded:: 0.6 .. versionadded:: 0.6
""" """
ndim = x.ndim if isinstance(repeats, (int, long, numpy.integer, list)):
shape = [x.shape[i] for i in xrange(ndim)] repeats = numpy.asarray(repeats)
if axis is None:
rep_tile = [1 for i in xrange(ndim)] + [repeats] if repeats.ndim > 1:
z = tensor.flatten(tensor.tile(x.reshape(shape+[1]), rep_tile)) raise ValueError('The dimension of repeats should not exceed 1.')
if repeats.ndim == 1:
return RepeatOp(axis=axis)(x, repeats)
else: else:
rep_tile = [1 for i in xrange(ndim)] if axis == None:
rep_tile[axis] = repeats axis = 0
z = tensor.tile(x, rep_tile) x = x.flatten()
return z else:
if axis >= x.ndim:
raise ValueError('Axis should not exceed x.ndim-1.')
shape = [x.shape[i] for i in xrange(x.ndim)]
shape_ = shape[:]
shape_.insert(axis+1, repeats)
shape[axis] = shape[axis]*repeats
dims = list(numpy.arange(x.ndim))
dims.insert(axis+1, 'x')
z = tensor.alloc(x.dimshuffle(*dims), *shape_).reshape(shape)
return z
class Bartlett(gof.Op): class Bartlett(gof.Op):
# See function bartlett for docstring # See function bartlett for docstring
def __eq__(self, other): def __eq__(self, other):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论