提交 a722d38d authored 作者: Frederic Bastien's avatar Frederic Bastien

small refactoring.

上级 a3677e68
......@@ -373,6 +373,11 @@ def get_constant_value(v):
if isinstance(v.owner.op, Subtensor) and v.ndim==0:
if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]]
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
#Needed to make better graph in this test.
#theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial
if (v.owner.inputs[0].owner and
......@@ -382,9 +387,6 @@ def get_constant_value(v):
# used in the sub-tensor).
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
# Note the '+ 1' is because the first argument to Join is the
# axis.
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]+1]
......@@ -398,9 +400,6 @@ def get_constant_value(v):
# We put this check in case there is change in the future
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret)
#MakeVector can cast implicitly its input in some case.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论