提交 1d6bd3c2 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

add more error checks, fix minor errors

上级 0dd5090c
...@@ -4711,10 +4711,15 @@ def tile(x, reps, ndim=None): ...@@ -4711,10 +4711,15 @@ def tile(x, reps, ndim=None):
See the docstring of `numpy.tile` for details. See the docstring of `numpy.tile` for details.
Currently, x.ndim and len(reps) must be equal, and, if specified, 'ndim' Currently, 'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
must be equal to both. symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())
or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).
If reps is constant vector/list/symbolic vector, the length of such vector(list)
should be less or equal to x.ndim, and if specified, 'ndim' should be less or equal
to x.ndim.
TODO: expand this. TODO: expand this so that it could support cases when the length of reps (or 'ndim')
is larger than x.ndim, which is supported by numpy.
""" """
...@@ -4722,53 +4727,49 @@ def tile(x, reps, ndim=None): ...@@ -4722,53 +4727,49 @@ def tile(x, reps, ndim=None):
if not isinstance(reps, (list, tuple)): if not isinstance(reps, (list, tuple)):
reps_astensor = as_tensor_variable(reps) reps_astensor = as_tensor_variable(reps)
ndim_check = reps_astensor.ndim ndim_check = reps_astensor.ndim
if reps_astensor.dtype not in theano.tensor.discrete_dtypes:
raise ValueError("elements of reps must be integer dtype")
# tensor.scalar/integer case # tensor.scalar/integer case
if ndim_check == 0: if ndim_check == 0:
reps_ = [] reps = [reps]
reps_.append(reps)
reps = reps_ # tensor.vector case
# tesor.vector case
elif ndim_check == 1: elif ndim_check == 1:
if ndim == None: if ndim is None:
raise ValueError("if reps is tensor.vector, you should specify " raise ValueError("if reps is tensor.vector, you should specify "
"the ndim") "the ndim")
else: else:
if ndim > x.ndim:
raise ValueError("ndim > x.ndim not currently supported")
offset = ndim-reps.shape[0] offset = ndim-reps.shape[0]
# assert that reps.shape[0] does not exceed ndim # assert that reps.shape[0] does not exceed ndim
offset = theano.tensor.opt.assert_(offset, ge(offset, 0)) offset = theano.tensor.opt.assert_(offset, ge(offset, 0))
# if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x.
reps_ = [switch(i<offset, 1, reps[i-offset]) for i in range(ndim)] reps_ = [switch(i<offset, 1, reps[i-offset]) for i in range(ndim)]
reps = reps_ reps = reps_
#other raise error #other raise error
else: else:
raise ValueError("the dimension of reps should not exceed 1") raise ValueError("the dimension of reps should not exceed 1")
else: else:
ndim_check = None if len(reps) > x.ndim:
raise ValueError("len(reps) > x.ndim not currently supported")
if not numpy.all([isinstance(r, (int, long)) or
(isinstance(r, TensorVariable) and
r.dtype in theano.tensor.discrete_dtypes) for r in reps]):
raise ValueError("elements of reps must be scalars of integer dtype")
# if reps.ndim is less than x.ndim, we pad the reps with # if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x. # "1" so that reps will have the same ndim as x.
reps = list(reps)
if len(reps) < x.ndim: if len(reps) < x.ndim:
reps = [1]*(x.ndim-len(reps)) + reps reps = [1]*(x.ndim-len(reps)) + reps
try:
iter(reps)
except TypeError:
raise ValueError("reps must be iterable")
# if reps is tensor.vector, this check should not go through
if not ndim_check == 1:
if not numpy.all([isinstance(r, integer_types) or
(isinstance(r, TensorVariable) and
r.dtype in ["int8", "int16", "int32", "int64"])
for r in reps]):
raise ValueError("elements of reps must be scalars of integer dtype")
if len(reps) != x.ndim:
raise ValueError("len(reps) != x.ndim not currently supported")
elif (ndim is not None) and ndim != x.ndim:
raise ValueError("if specified, ndim must be equal to both x.ndim and "
"len(reps)")
if ndim is None: if ndim is None:
ndim = len(reps) ndim = len(reps)
reps = list(reps) reps = list(reps)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论