提交 b1259699 authored 作者: Sigurd Spieckermann's avatar Sigurd Spieckermann

use get_scalar_constant_value for ndim == 0 again; add error message to TypeError exception

上级 53756433
...@@ -1368,18 +1368,21 @@ def local_subtensor_make_vector(node): ...@@ -1368,18 +1368,21 @@ def local_subtensor_make_vector(node):
if isinstance(idx, (int, numpy.integer)): if isinstance(idx, (int, numpy.integer)):
return [x.owner.inputs[idx]] return [x.owner.inputs[idx]]
elif isinstance(idx, Variable): elif isinstance(idx, Variable):
# if it is a constant we can do something with it if idx.ndim == 0:
if isinstance(idx, T.Constant): # if it is a constant we can do something with it
# make sure we have an ndarray to access the `ndim` attribute try:
idx = numpy.asarray(idx.value) v = get_scalar_constant_value(idx)
if idx.ndim == 0: if isinstance(v, numpy.integer):
# Python 2.4 wants to index only with Python integers # Python 2.4 wants to index only with Python integers
return [x.owner.inputs[int(idx)]] v = int(v)
elif idx.ndim == 1: return [x.owner.inputs[v]]
values = map(int, list(idx)) except NotScalarConstantError:
return [make_vector(*[x.owner.inputs[v] for v in values])] pass
else: elif idx.ndim == 1 and isinstance(idx, T.Constant):
raise TypeError values = map(int, list(idx.value))
return [make_vector(*[x.owner.inputs[v] for v in values])]
else:
raise TypeError('case not expected')
else: else:
# it is a slice of ints and/or Variables # it is a slice of ints and/or Variables
#TODO: check subtensor to see if it can contain #TODO: check subtensor to see if it can contain
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论