提交 4017762f authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Fixed potential bug in get_constant_value, and made it a bit safer

上级 d22c9272
...@@ -372,11 +372,22 @@ def get_constant_value(v): ...@@ -372,11 +372,22 @@ def get_constant_value(v):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]] return v.owner.inputs[0].data[v.owner.op.idx_list[0]]
#Needed to make better graph in this test. #Needed to make better graph in this test.
#theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial #theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial
if v.owner.inputs[0].owner and isinstance(v.owner.inputs[0].owner.op, Join): if (v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, Join) and
# Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the one
# used in the sub-tensor).
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
# Note the '+ 1' is because the first argument to Join is the
# axis.
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]+1] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]+1]
ret = get_constant_value(ret) ret = get_constant_value(ret)
#join can cast implicitly its input in some case. #join can cast implicitly its input in some case.
return numpy.asarray(ret, dtype=v.type.dtype) return theano._asarray(ret, dtype=v.type.dtype)
raise TypeError(v) raise TypeError(v)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论