提交 848fa7b5 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

fixed bug in get_constant_value's handling of Subtensor

上级 c3b5c808
...@@ -530,9 +530,17 @@ def get_constant_value(v): ...@@ -530,9 +530,17 @@ def get_constant_value(v):
v.owner.op.perform(v.owner, [const], ret) v.owner.op.perform(v.owner, [const], ret)
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, Subtensor) and v.ndim == 0: if isinstance(v.owner.op, Subtensor) and v.ndim == 0:
if isinstance(v.owner.inputs[0], TensorConstant): # This condition depends on Subtensor always embedding constant
return v.owner.inputs[0].data.__getitem__( # indices in the Op rather than making them inputs to the Apply node
if isinstance(v.owner.inputs[0], TensorConstant) and \
len(v.owner.inputs) == 1:
try:
return v.owner.inputs[0].data.__getitem__(
tuple(v.owner.op.idx_list)) tuple(v.owner.op.idx_list))
except IndexError:
raise IndexError(str(tuple(v.owner.op.idx_list))+" is not a valid index into " + \
str(v.owner.inputs[0].data))
# The index list 'idx_list' should have length the same # The index list 'idx_list' should have length the same
# shape as the input. # shape as the input.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论