提交 1d6bd3c2 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

add more error checks, fix minor errors

上级 0dd5090c
......@@ -4711,10 +4711,15 @@ def tile(x, reps, ndim=None):
See the docstring of `numpy.tile` for details.
Currently, x.ndim and len(reps) must be equal, and, if specified, 'ndim'
must be equal to both.
Currently, 'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())
or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).
If reps is constant vector/list/symbolic vector, the length of such vector(list)
should be less or equal to x.ndim, and if specified, 'ndim' should be less or equal
to x.ndim.
TODO: expand this.
TODO: expand this so that it could support cases when the length of reps (or 'ndim')
is larger than x.ndim, which is supported by numpy.
"""
......@@ -4722,52 +4727,48 @@ def tile(x, reps, ndim=None):
if not isinstance(reps, (list, tuple)):
reps_astensor = as_tensor_variable(reps)
ndim_check = reps_astensor.ndim
if reps_astensor.dtype not in theano.tensor.discrete_dtypes:
raise ValueError("elements of reps must be integer dtype")
# tensor.scalar/integer case
if ndim_check == 0:
reps_ = []
reps_.append(reps)
reps = reps_
# tesor.vector case
reps = [reps]
# tensor.vector case
elif ndim_check == 1:
if ndim == None:
if ndim is None:
raise ValueError("if reps is tensor.vector, you should specify "
"the ndim")
else:
if ndim > x.ndim:
raise ValueError("ndim > x.ndim not currently supported")
offset = ndim-reps.shape[0]
# assert that reps.shape[0] does not exceed ndim
offset = theano.tensor.opt.assert_(offset, ge(offset, 0))
# if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x.
reps_ = [switch(i<offset, 1, reps[i-offset]) for i in range(ndim)]
reps = reps_
#other raise error
else:
raise ValueError("the dimension of reps should not exceed 1")
else:
ndim_check = None
if len(reps) > x.ndim:
raise ValueError("len(reps) > x.ndim not currently supported")
if not numpy.all([isinstance(r, (int, long)) or
(isinstance(r, TensorVariable) and
r.dtype in theano.tensor.discrete_dtypes) for r in reps]):
raise ValueError("elements of reps must be scalars of integer dtype")
# if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x.
# "1" so that reps will have the same ndim as x.
reps = list(reps)
if len(reps) < x.ndim:
reps = [1]*(x.ndim-len(reps)) + reps
try:
iter(reps)
except TypeError:
raise ValueError("reps must be iterable")
# if reps is tensor.vector, this check should not go through
if not ndim_check == 1:
if not numpy.all([isinstance(r, integer_types) or
(isinstance(r, TensorVariable) and
r.dtype in ["int8", "int16", "int32", "int64"])
for r in reps]):
raise ValueError("elements of reps must be scalars of integer dtype")
if len(reps) != x.ndim:
raise ValueError("len(reps) != x.ndim not currently supported")
elif (ndim is not None) and ndim != x.ndim:
raise ValueError("if specified, ndim must be equal to both x.ndim and "
"len(reps)")
reps = [1]*(x.ndim-len(reps)) + reps
if ndim is None:
ndim = len(reps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论