提交 b51cfd09 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Ignore univariate Elemwise Ops in get_vector_length

This allows one to obtain the length of a fixed-length vector that has--for example--been cast to a different datatype, squared, etc.
上级 ca215a2d
......@@ -108,6 +108,7 @@ from theano.tensor import (
as_tensor_variable,
batched_dot,
bvector,
cast,
choose,
clip,
constant,
......@@ -2368,24 +2369,43 @@ class TestOuter:
utt.verify_grad(tt.outer, [data0, data1])
class TestGetVectorLength:
def test_get_vector_length(self):
x = theano.shared(np.zeros((2, 3, 4, 5)))
assert len(list(x.shape)) == 4
assert len(list(x.shape[2:4])) == 2
assert len(list(x.shape[2:])) == 2
assert len(list(x.shape[1:4])) == 3
assert len(list(x.shape[2:2])) == 0
assert len(list(x.shape[1:5])) == 3
assert len(list(x.shape[1:10])) == 3
# Test step
assert len(list(x.shape[1:10:2])) == 2
# Test neg start
assert len(list(x.shape[-1:4])) == 1
assert len(list(x.shape[-6:4])) == 4
# test neg stop
assert len(list(x.shape[1:-2])) == 1
assert len(list(x.shape[1:-1])) == 2
def test_get_vector_length():
x = theano.shared(np.zeros((2, 3, 4, 5)))
assert len(list(x.shape)) == 4
assert len(list(x.shape[2:4])) == 2
assert len(list(x.shape[2:])) == 2
assert len(list(x.shape[1:4])) == 3
assert len(list(x.shape[2:2])) == 0
assert len(list(x.shape[1:5])) == 3
assert len(list(x.shape[1:10])) == 3
# Test step
assert len(list(x.shape[1:10:2])) == 2
# Test neg start
assert len(list(x.shape[-1:4])) == 1
assert len(list(x.shape[-6:4])) == 4
# test neg stop
assert len(list(x.shape[1:-2])) == 1
assert len(list(x.shape[1:-1])) == 2
empty_tuple = as_tensor_variable(())
assert 0 == get_vector_length(empty_tuple)
x = lscalar("x")
y = dscalar("y")
triple = as_tensor_variable((x, y, 9.0))
assert 3 == get_vector_length(triple)
triple = cast(as_tensor_variable((x, y, 9.0)), "int64")
assert 3 == get_vector_length(triple)
a, b, c = triple
mode = theano.compile.get_default_mode().excluding("constant_folding")
f = function([x, y], [b, c, a], mode=mode)
topo = f.maker.fgraph.toposort()
assert [True for node in topo if isinstance(node.op, opt.MakeVector)]
assert np.allclose(f(4, 5), [5, 9, 4])
class TestJoinAndSplit:
......@@ -2865,20 +2885,6 @@ class TestJoinAndSplit:
utt.verify_grad(lambda a, b: join(-1, a, b), [v, 2 * v], mode=self.mode)
def test_vector_len(self):
x = lscalar("x")
y = dscalar("y")
triple = as_tensor_variable((x, y, 9.0))
assert 3 == get_vector_length(triple)
a, b, c = triple
f = function([x, y], [b, c, a], mode=self.mode)
topo = f.maker.fgraph.toposort()
assert [True for node in topo if isinstance(node.op, opt.MakeVector)]
assert np.allclose(f(4, 5), [5, 9, 4])
def test_broadcastable_flag_assignment_mixed_otheraxes(self):
# Test that the broadcastable flags for the output of
# a join operation on non-join axes are True if one or
......
......@@ -4924,6 +4924,17 @@ def get_vector_length(v):
return len(v.owner.inputs)
if v.owner and isinstance(v.owner.op, Shape):
return v.owner.inputs[0].type.ndim
# We can skip `Op`s that don't affect the length, like unary `Elemwise`
# `Op`s
if (
v.owner
and isinstance(v.owner.op, theano.tensor.elemwise.Elemwise)
and len(v.owner.inputs) == 1
and len(v.owner.outputs) == 1
):
return get_vector_length(v.owner.inputs[0])
# If we take a slice, we know how many elements it will result in
if (
v.owner
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论