提交 650c952d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3727 from nouiz/vector_length

[BUG] fix get_vector length. fix gh-3722
...@@ -4243,18 +4243,40 @@ def get_vector_length(v): ...@@ -4243,18 +4243,40 @@ def get_vector_length(v):
# If we take a slice, we know how many elements it will result in # If we take a slice, we know how many elements it will result in
if ((v.owner and if ((v.owner and
isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and
isinstance(v.owner.op.idx_list[0], slice))): isinstance(v.owner.op.idx_list[0], slice) and
v.owner.inputs[0].owner and
isinstance(v.owner.inputs[0].owner.op, theano.compile.ops.Shape))):
start = extract_constant(theano.tensor.subtensor.get_idx_list( start = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].start) v.owner.inputs, v.owner.op.idx_list)[0].start)
stop = extract_constant(theano.tensor.subtensor.get_idx_list( stop = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].stop) v.owner.inputs, v.owner.op.idx_list)[0].stop)
step = extract_constant(theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list)[0].step)
ndim = v.owner.inputs[0].owner.inputs[0].ndim
types = (numbers.Integral, numpy.integer)
if start is None: if start is None:
start = 0 start = 0
elif isinstance(start, types) and start < 0:
start += ndim
if start < 0:
start = 0
if stop is None: if stop is None:
stop = 0 stop = ndim
if ((isinstance(stop, numbers.Integral) and elif isinstance(stop, types):
isinstance(start, numbers.Integral))): if stop > ndim:
return stop - start stop = ndim
elif stop < 0:
stop += ndim
if step is None:
step = 1
if (isinstance(stop, types) and
isinstance(start, types) and
isinstance(step, types) and
start >= 0 and stop >= 0 and
step > 0 and stop >= start):
return (stop - start - 1) // step + 1
if isinstance(v, Variable): if isinstance(v, Variable):
msg = theano.printing.debugprint(v, file='str') msg = theano.printing.debugprint(v, file='str')
else: else:
......
...@@ -3346,6 +3346,26 @@ class T_outer(unittest.TestCase): ...@@ -3346,6 +3346,26 @@ class T_outer(unittest.TestCase):
utt.verify_grad(tensor.outer, [data0, data1]) utt.verify_grad(tensor.outer, [data0, data1])
class T_GetVectorLength(unittest.TestCase):
def test_get_vector_length(self):
x = theano.shared(numpy.zeros((2, 3, 4, 5)))
assert len(list(x.shape)) == 4
assert len(list(x.shape[2:4])) == 2
assert len(list(x.shape[2:])) == 2
assert len(list(x.shape[1:4])) == 3
assert len(list(x.shape[2:2])) == 0
assert len(list(x.shape[1:5])) == 3
assert len(list(x.shape[1:10])) == 3
# Test step
assert len(list(x.shape[1:10:2])) == 2
# Test neg start
assert len(list(x.shape[-1:4])) == 1
assert len(list(x.shape[-6:4])) == 4
# test neg stop
assert len(list(x.shape[1:-2])) == 1
assert len(list(x.shape[1:-1])) == 2
class T_Join_and_Split(unittest.TestCase): class T_Join_and_Split(unittest.TestCase):
""" """
Split is tested by each verify_grad method. Split is tested by each verify_grad method.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论