提交 f31cc48f authored 作者: Frederic's avatar Frederic

[BUG] fix get_vector length. fix gh-3722

上级 ea96b166
......@@ -4243,15 +4243,20 @@ def get_vector_length(v):
# If we take a slice, we know how many elements it will result in
if ((v.owner and
isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
isinstance(v.owner.op.idx_list[0], slice))):
isinstance(v.owner.op.idx_list[0], slice) and
v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, theano.compile.ops.Shape))):
start = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].start)
stop = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].stop)
if start is None:
start = 0
ndim = v.owner.inputs[0].owner.inputs[0].ndim
if stop is None:
stop = 0
stop = ndim
elif isinstance(stop, numbers.Integral) and stop > ndim :
stop = ndim
if ((isinstance(stop, numbers.Integral) and
isinstance(start, numbers.Integral))):
return stop - start
......
......@@ -3345,6 +3345,17 @@ class T_outer(unittest.TestCase):
utt.verify_grad(tensor.outer, [data0, data1])
class T_GetVectorLength(unittest.TestCase):
def test_get_vector_length(self):
x = theano.shared(numpy.zeros((2, 3, 4, 5)))
assert len(list(x.shape)) == 4
assert len(list(x.shape[2:4])) == 2
assert len(list(x.shape[2:])) == 2
assert len(list(x.shape[1:4])) == 3
assert len(list(x.shape[1:5])) == 3
assert len(list(x.shape[1:10])) == 3
class T_Join_and_Split(unittest.TestCase):
"""
Split is tested by each verify_grad method.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论