提交 0a532cbd authored 作者: Eric Larsen's avatar Eric Larsen 提交者: Frederic

corrections to infer_shape in Reshape and Tile

上级 14181bd4
......@@ -5366,9 +5366,9 @@ class Reshape(Op):
for i, ele in enumerate(requ):
if ele == -1:
requ[i] = missing
elif crit > 1:
raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1')
elif crit > 1:
raise ValueError('shape argument to Reshape.perform'
' must have at most one entry equal to -1')
return [requ]
else:
oshape = []
......@@ -5548,9 +5548,18 @@ class Tile(Op):
def infer_shape(self, node, in_shapes):
# Note: in contrast with numpy, it is assumed that x.shape and reps
# have equal length; see alsor tile function below
# have equal length; see also tile function below
# Note: if reps were to be allowed not to be a constant and x.shape
# and reps to be unequal, the following block of code could be used:
## prepend 1 to x.shape if needed
# if self.ndim > x.ndim:
# shp = concatenate(ones(self.ndim - x.ndim), shp)
## prepend 1 to reps if needed
# reps = concatenate(ones(self.ndim - reps.shape[0]), reps)
x, reps = node.inputs
shp = x.shape
shp = in_shapes[0]
tiled_shp = shp * reps
out_shape = []
for i in range(self.ndim):
......@@ -5575,10 +5584,12 @@ def tile(x, reps, ndim=None):
TODO: expand this.
"""
if isinstance(reps, theano.tensor.TensorVariable):
raise ValueError("'reps' argument to 'tile' must be a constant (e.g. "
"tuple, list of integers)")
elif len(reps) != x.ndim:
try:
assert python_all([int(i) == i for i in iter(reps)])
except (TypeError, AssertionError):
raise ValueError("reps argument to tile must be a constant (e.g. "
"tuple, list of integers)")
if len(reps) != x.ndim:
raise ValueError("len(reps) != x.ndim not currently supported")
elif (ndim is not None) and ndim != x.ndim:
raise ValueError("if specified, ndim must be equal to both x.ndim and "
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论