提交 1949bf69 authored 作者: Saizheng Zhang's avatar Saizheng Zhang

more changes

上级 1d6bd3c2
"""A `Type` and `Op` classes to work with numpy.ndarrays symbolically."""
import __builtin__
import sys
import warnings
......@@ -4711,18 +4712,20 @@ def tile(x, reps, ndim=None):
See the docstring of `numpy.tile` for details.
Currently, 'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
'reps' can be constant integer (e.g. 3), constant vector(e.g. [2 3]),
symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector())
or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]).
If reps is constant vector/list/symbolic vector, the length of such vector(list)
should be less or equal to x.ndim, and if specified, 'ndim' should be less or equal
to x.ndim.
TODO: expand this so that it could support cases when the length of reps (or 'ndim')
is larger than x.ndim, which is supported by numpy.
ndim is the number of the dimensions of the output, if it is provided, ndim
should be equal or larger than x.ndim and len(reps), otherwise, we will use
max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to
be provided.
"""
if ndim is not None and ndim < x.ndim:
raise ValueError("ndim should be equal or larger than x.ndim")
# if reps is tensor.scalar, integer or tensor.vector, we convert it to a list.
if not isinstance(reps, (list, tuple)):
reps_astensor = as_tensor_variable(reps)
......@@ -4740,9 +4743,6 @@ def tile(x, reps, ndim=None):
raise ValueError("if reps is tensor.vector, you should specify "
"the ndim")
else:
if ndim > x.ndim:
raise ValueError("ndim > x.ndim not currently supported")
offset = ndim-reps.shape[0]
# assert that reps.shape[0] does not exceed ndim
......@@ -4757,22 +4757,21 @@ def tile(x, reps, ndim=None):
else:
raise ValueError("the dimension of reps should not exceed 1")
else:
if len(reps) > x.ndim:
raise ValueError("len(reps) > x.ndim not currently supported")
if ndim is not None and len(reps) > ndim:
raise ValueError("len(reps) should be equal or less than ndim")
if not numpy.all([isinstance(r, (int, long)) or
(isinstance(r, TensorVariable) and
r.dtype in theano.tensor.discrete_dtypes) for r in reps]):
raise ValueError("elements of reps must be scalars of integer dtype")
# if reps.ndim is less than x.ndim, we pad the reps with
# "1" so that reps will have the same ndim as x.
reps = list(reps)
if len(reps) < x.ndim:
reps = [1]*(x.ndim-len(reps)) + reps
if ndim is None:
ndim = len(reps)
reps = list(reps)
if ndim is None:
ndim = __builtin__.max(len(reps), x.ndim)
if len(reps) < ndim:
reps = [1]*(ndim-len(reps)) + reps
shape = [x.shape[i] for i in xrange(ndim)]
alloc_shape = reps + shape
y = alloc(x, *alloc_shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论