提交 d064f3db authored 作者: Saizheng Zhang's avatar Saizheng Zhang

tile reps supports integer, tensor.scalar and tensor.vector

上级 c8dc3dbe
...@@ -4718,16 +4718,52 @@ def tile(x, reps, ndim=None): ...@@ -4718,16 +4718,52 @@ def tile(x, reps, ndim=None):
""" """
# if reps is tensor.scalar, integer or tensor.vector, we convert it to a list.
if not isinstance(reps, (list, tuple)):
reps_astensor = as_tensor_variable(reps)
ndim_check = reps_astensor.ndim
# tensor.scalar/integer case
if ndim_check == 0:
reps_ = []
reps_.append(reps)
reps = reps_
# tesor.vector case
elif ndim_check == 1:
if ndim == None:
raise ValueError("if reps is tensor.vector, you should specify "
"the ndim")
else:
offset = ndim-reps.shape[0]
# assert that reps.shape[0] does not exceed ndim
offset = theano.tensor.opt.assert_(offset, ge(offset, 0))
reps_ = [switch(i<offset, 1, reps[i-offset]) for i in range(ndim)]
reps = reps_
#other raise error
else:
raise ValueError("the dimension of reps should not exceed 1")
else:
ndim_check = None
# if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x.
if len(reps) < x.ndim:
reps = [1]*(x.ndim-len(reps)) + reps
try: try:
iter(reps) iter(reps)
except TypeError: except TypeError:
raise ValueError("reps must be iterable") raise ValueError("reps must be iterable")
if not numpy.all([isinstance(r, integer_types) or
(isinstance(r, TensorVariable) and # if reps is tensor.vector, this check should not go through
r.dtype in ["int8", "int16", "int32", "int64"]) if not ndim_check == 1:
for r in reps]): if not numpy.all([isinstance(r, integer_types) or
raise ValueError("elements of reps must be scalars of integer dtype") (isinstance(r, TensorVariable) and
elif len(reps) != x.ndim: r.dtype in ["int8", "int16", "int32", "int64"])
for r in reps]):
raise ValueError("elements of reps must be scalars of integer dtype")
if len(reps) != x.ndim:
raise ValueError("len(reps) != x.ndim not currently supported") raise ValueError("len(reps) != x.ndim not currently supported")
elif (ndim is not None) and ndim != x.ndim: elif (ndim is not None) and ndim != x.ndim:
raise ValueError("if specified, ndim must be equal to both x.ndim and " raise ValueError("if specified, ndim must be equal to both x.ndim and "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论