提交 bee0c185 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

change type-check, added comments

上级 9f8efc6c
......@@ -647,8 +647,7 @@ def repeat(x, repeats, axis=None):
.. versionadded:: 0.6
"""
if isinstance(repeats, (int, long, numpy.integer, list)):
repeats = numpy.asarray(repeats)
repeats = tensor.as_tensor_variable(repeats)
if repeats.ndim > 1:
raise ValueError('The dimension of repeats should not exceed 1.')
......@@ -666,14 +665,20 @@ def repeat(x, repeats, axis=None):
axis = x.ndim+axis
shape = [x.shape[i] for i in xrange(x.ndim)]
# shape_ is the shape of the intermediate tensor which has
# an additional dimension comparing to x. We use alloc to
# launch space for this intermediate tensor to replicate x
# along that additional dimension.
shape_ = shape[:]
shape_.insert(axis+1, repeats)
shape[axis] = shape[axis]*repeats
dims = list(numpy.arange(x.ndim))
dims.insert(axis+1, 'x')
z = tensor.alloc(x.dimshuffle(*dims), *shape_).reshape(shape)
# dims_ is the dimension of that intermediate tensor.
dims_ = list(numpy.arange(x.ndim))
dims_.insert(axis+1, 'x')
z = tensor.alloc(x.dimshuffle(*dims_), *shape_).reshape(shape)
return z
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论