提交 f8927951 authored 作者: Frederic Bastien's avatar Frederic Bastien

don't raise an AssertionError when it is a case that is not implemented.

上级 bdcad1da
...@@ -374,9 +374,10 @@ def get_constant_value(v): ...@@ -374,9 +374,10 @@ def get_constant_value(v):
if isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data.__getitem__(tuple(v.owner.op.idx_list)) return v.owner.inputs[0].data.__getitem__(tuple(v.owner.op.idx_list))
# The index list 'idx_list' should have length one # The index list 'idx_list' should have length the same shape as the
# since joining scalar variables results in a 1D vector. # input.
assert len(v.owner.op.idx_list) == 1 # TODO: implement the case where we take a scalar in a matrix
assert len(v.owner.op.idx_list) == v.owner.inputs[0].ndim
#Needed to make better graph in this test. #Needed to make better graph in this test.
#theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial #theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial
...@@ -385,7 +386,8 @@ def get_constant_value(v): ...@@ -385,7 +386,8 @@ def get_constant_value(v):
# Ensure the Join is joining only scalar variables (so that # Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the one # the constant value can be found at the same index as the one
# used in the sub-tensor). # used in the sub-tensor).
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)): all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
# Note the '+ 1' is because the first argument to Join is the # Note the '+ 1' is because the first argument to Join is the
# axis. # axis.
...@@ -398,7 +400,8 @@ def get_constant_value(v): ...@@ -398,7 +400,8 @@ def get_constant_value(v):
theano.tensor.opt.MakeVector) and theano.tensor.opt.MakeVector) and
# MakeVector normally accept only scalar as input. # MakeVector normally accept only scalar as input.
# We put this check in case there is change in the future # We put this check in case there is change in the future
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)): all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]] ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret) ret = get_constant_value(ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论