提交 8b2ad7df authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor arg checking and processing.

This fixes a bug when the shape is a tuple containing both symbolic and numeric scalars.
上级 2e7d332f
......@@ -301,14 +301,12 @@ def _infer_ndim_bcast(ndim, shape, *args):
else:
args_ndim = 0
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
if (isinstance(shape, (tuple, list))
and numpy.all(numpy.asarray(shape) >= 0)):
bcast = [(s == 1) for s in shape]
v_shape = tensor.TensorConstant(type=tensor.lvector,
data=theano._asarray(shape,
dtype='int64'))
if isinstance(shape, (tuple, list)):
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
#
# This case combines together symbolic and non-symbolic shape
# information
shape_ndim = len(shape)
if ndim is None:
ndim = shape_ndim
......@@ -317,18 +315,7 @@ def _infer_ndim_bcast(ndim, shape, *args):
raise ValueError('ndim should be equal to len(shape), but\n',
'ndim = %s, len(shape) = %s, shape = %s'
% (ndim, shape_ndim, shape))
elif isinstance(shape, (tuple, list)):
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
#
# This case combines together symbolic and non-symbolic shape
# information
if ndim is None:
ndim = args_ndim
else:
ndim = max(args_ndim, ndim)
ndim = max(args_ndim, len(shape))
shape = [-1] * (ndim - len(shape)) + list(shape)
bcast = []
pre_v_shape = []
for i, s in enumerate(shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论