提交 a722d38d authored 作者: Frederic Bastien's avatar Frederic Bastien

small refactoring.

上级 a3677e68
...@@ -373,6 +373,11 @@ def get_constant_value(v): ...@@ -373,6 +373,11 @@ def get_constant_value(v):
if isinstance(v.owner.op, Subtensor) and v.ndim==0: if isinstance(v.owner.op, Subtensor) and v.ndim==0:
if isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data[v.owner.op.idx_list[0]] return v.owner.inputs[0].data[v.owner.op.idx_list[0]]
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
#Needed to make better graph in this test. #Needed to make better graph in this test.
#theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial #theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial
if (v.owner.inputs[0].owner and if (v.owner.inputs[0].owner and
...@@ -382,9 +387,6 @@ def get_constant_value(v): ...@@ -382,9 +387,6 @@ def get_constant_value(v):
# used in the sub-tensor). # used in the sub-tensor).
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)): all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
# Note the '+ 1' is because the first argument to Join is the # Note the '+ 1' is because the first argument to Join is the
# axis. # axis.
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]+1] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]+1]
...@@ -398,9 +400,6 @@ def get_constant_value(v): ...@@ -398,9 +400,6 @@ def get_constant_value(v):
# We put this check in case there is change in the future # We put this check in case there is change in the future
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)): all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret) ret = get_constant_value(ret)
#MakeVector can cast implicitly its input in some case. #MakeVector can cast implicitly its input in some case.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论