提交 b52586a8 authored 作者: Frederic Bastien's avatar Frederic Bastien

make get_constant_value dig into Subtensor when they return a scalar.

上级 8170cdf1
...@@ -362,6 +362,8 @@ def get_constant_value(v): ...@@ -362,6 +362,8 @@ def get_constant_value(v):
ret = [[None]] ret = [[None]]
v.owner.op.perform(v.owner, [const], ret) v.owner.op.perform(v.owner, [const], ret)
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, Subtensor) and v.ndim==0 and isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]]
raise TypeError(v) raise TypeError(v)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论