提交 ace3920c authored 作者: nouiz's avatar nouiz

Merge pull request #1201 from goodfeli/shorter_error

Shorter error
...@@ -603,11 +603,33 @@ def get_scalar_constant_value(v): ...@@ -603,11 +603,33 @@ def get_scalar_constant_value(v):
# This is needed when we take the grad as the Shape op # This is needed when we take the grad as the Shape op
# are not already changed into MakeVector # are not already changed into MakeVector
if (v.owner.inputs[0].owner and owner = v.owner
isinstance(v.owner.inputs[0].owner.op, leftmost_parent = owner.inputs[0]
if (leftmost_parent.owner and
isinstance(leftmost_parent.owner.op,
theano.tensor.Shape)): theano.tensor.Shape)):
if v.owner.inputs[0].owner.inputs[0].type.broadcastable[ op = owner.op
v.owner.op.idx_list[0]]: idx_list = op.idx_list
idx = idx_list[0]
grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim
assert ndim == len(gp_broadcastable)
if not (idx < len(gp_broadcastable)):
msg = "get_scalar_constant_value detected " + \
"deterministic IndexError: x.shape[%d] " + \
"when x.ndim=%d." % (ndim, idx)
if config.exception_verbosity == 'high':
msg += 'x=%s' % min_informative_str(x)
else:
msg += 'x=%s' % str(x)
raise ValueError(msg)
if gp_broadcastable[idx]:
return numpy.asarray(1) return numpy.asarray(1)
raise NotScalarConstantError(v) raise NotScalarConstantError(v)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论