提交 f8927951 authored 作者: Frederic Bastien's avatar Frederic Bastien

don't raise an AssertionError when it is a case that is not implemented.

上级 bdcad1da
......@@ -374,9 +374,10 @@ def get_constant_value(v):
if isinstance(v.owner.inputs[0], TensorConstant):
return v.owner.inputs[0].data.__getitem__(tuple(v.owner.op.idx_list))
# The index list 'idx_list' should have length one
# since joining scalar variables results in a 1D vector.
assert len(v.owner.op.idx_list) == 1
# The index list 'idx_list' should have length the same shape as the
# input.
# TODO: implement the case where we take a scalar in a matrix
assert len(v.owner.op.idx_list) == v.owner.inputs[0].ndim
#Needed to make better graph in this test.
#theano/tensor/tests/test_sharedvar.py:test_shared_options.test_specify_shape_partial
......@@ -385,7 +386,8 @@ def get_constant_value(v):
# Ensure the Join is joining only scalar variables (so that
# the constant value can be found at the same index as the one
# used in the sub-tensor).
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
# Note the '+ 1' is because the first argument to Join is the
# axis.
......@@ -398,7 +400,8 @@ def get_constant_value(v):
theano.tensor.opt.MakeVector) and
# MakeVector normally accept only scalar as input.
# We put this check in case there is change in the future
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs)):
all(var.ndim==0 for var in v.owner.inputs[0].owner.inputs) and
len(v.owner.op.idx_list) == 1):
ret = v.owner.inputs[0].owner.inputs[v.owner.op.idx_list[0]]
ret = get_constant_value(ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论