提交 5a378b80 authored 作者: Frederic Bastien's avatar Frederic Bastien

make get_constant_value return constant in one more case needed for a futur test.

上级 24f06dd1
...@@ -367,8 +367,16 @@ def get_constant_value(v): ...@@ -367,8 +367,16 @@ def get_constant_value(v):
ret = [[None]] ret = [[None]]
v.owner.op.perform(v.owner, [const], ret) v.owner.op.perform(v.owner, [const], ret)
return ret[0][0] return ret[0][0]
if isinstance(v.owner.op, Subtensor) and v.ndim==0 and isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.op, Subtensor) and v.ndim==0:
return v.owner.inputs[0].data[v.owner.op.idx_list[0]] if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]]
#Needed to make better graph in this test.
#theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial
if v.owner.inputs[0].owner and isinstance(v.owner.inputs[0].owner.op, Join):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]+1]
ret = get_constant_value(ret)
#join can cast implicitly its input in some case.
return numpy.asarray(ret, dtype=v.type.dtype)
raise TypeError(v) raise TypeError(v)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论