提交 8b2ad7df authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Refactor arg checking and processing.

This fixes a bug when the shape is a tuple containing both symbolic and numeric scalars.
上级 2e7d332f
...@@ -301,14 +301,12 @@ def _infer_ndim_bcast(ndim, shape, *args): ...@@ -301,14 +301,12 @@ def _infer_ndim_bcast(ndim, shape, *args):
else: else:
args_ndim = 0 args_ndim = 0
# there is a convention that -1 means the corresponding shape of a if isinstance(shape, (tuple, list)):
# potentially-broadcasted symbolic arg # there is a convention that -1 means the corresponding shape of a
if (isinstance(shape, (tuple, list)) # potentially-broadcasted symbolic arg
and numpy.all(numpy.asarray(shape) >= 0)): #
bcast = [(s == 1) for s in shape] # This case combines together symbolic and non-symbolic shape
v_shape = tensor.TensorConstant(type=tensor.lvector, # information
data=theano._asarray(shape,
dtype='int64'))
shape_ndim = len(shape) shape_ndim = len(shape)
if ndim is None: if ndim is None:
ndim = shape_ndim ndim = shape_ndim
...@@ -317,18 +315,7 @@ def _infer_ndim_bcast(ndim, shape, *args): ...@@ -317,18 +315,7 @@ def _infer_ndim_bcast(ndim, shape, *args):
raise ValueError('ndim should be equal to len(shape), but\n', raise ValueError('ndim should be equal to len(shape), but\n',
'ndim = %s, len(shape) = %s, shape = %s' 'ndim = %s, len(shape) = %s, shape = %s'
% (ndim, shape_ndim, shape)) % (ndim, shape_ndim, shape))
elif isinstance(shape, (tuple, list)):
# there is a convention that -1 means the corresponding shape of a
# potentially-broadcasted symbolic arg
#
# This case combines together symbolic and non-symbolic shape
# information
if ndim is None:
ndim = args_ndim
else:
ndim = max(args_ndim, ndim)
ndim = max(args_ndim, len(shape))
shape = [-1] * (ndim - len(shape)) + list(shape)
bcast = [] bcast = []
pre_v_shape = [] pre_v_shape = []
for i, s in enumerate(shape): for i, s in enumerate(shape):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论