提交 8685e8cb authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix get_scalar_constant_value().

上级 9ca3be9c
......@@ -590,14 +590,23 @@ def get_scalar_constant_value(v):
v.owner.op.perform(v.owner, const, ret)
return ret[0][0]
if isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and v.ndim == 0:
# This condition depends on Subtensor always embedding constant
# indices in the Op rather than making them inputs to the Apply
# node.
if isinstance(v.owner.inputs[0], TensorConstant) and \
len(v.owner.inputs) == 1:
if isinstance(v.owner.inputs[0], TensorConstant):
indices = list(reversed(list(v.owner.inputs[1:])))
def conv(e):
if isintance(e, gof.Type):
return get_constant_scalar_value(indices.pop())
elif isinstance(e, slice):
return slice(conv(e.start),
conv(e.stop),
conv(e.step))
elif isintance(e, (int, long, numpy.integer)):
return int(e)
else:
raise NotScalarConstantError(v)
cdata = tuple(map(conv, v.owner.op.idx_list))
try:
return v.owner.inputs[0].data.__getitem__(
tuple(v.owner.op.idx_list))
return v.owner.inputs[0].data.__getitem__(cdata)
except IndexError:
raise IndexError(
str(tuple(v.owner.op.idx_list)) +
......@@ -620,10 +629,12 @@ def get_scalar_constant_value(v):
v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(v.owner.inputs[1])
# Note the '+ 1' is because the first argument to Join is the
# axis.
ret = v.owner.inputs[0].owner.inputs[
v.owner.op.idx_list[0] + 1]
ret = v.owner.inputs[0].owner.inputs[idx + 1]
ret = get_scalar_constant_value(ret)
# join can cast implicitly its input in some case.
return theano._asarray(ret, dtype=v.type.dtype)
......@@ -636,13 +647,12 @@ def get_scalar_constant_value(v):
python_all(var.ndim == 0 for var in
v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1 and
#idx_list can contain Scalar Type object.
isinstance(v.owner.op.idx_list[0], (int, long,
numpy.integer))):
idx = v.owner.op.idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(v.owner.inputs[1])
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(v.owner.op.idx_list[0])
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_scalar_constant_value(ret)
# MakeVector can cast implicitly its input in some case.
......@@ -658,6 +668,8 @@ def get_scalar_constant_value(v):
op = owner.op
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, gof.Type):
idx = get_scalar_constant_value(owner.inputs[1])
grandparent = leftmost_parent.owner.inputs[0]
gp_broadcastable = grandparent.type.broadcastable
ndim = grandparent.type.ndim
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论