提交 b83d0a64 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make get_vector_length handle simple Join Ops

上级 3a2556a8
...@@ -2387,6 +2387,16 @@ def test_get_vector_length(): ...@@ -2387,6 +2387,16 @@ def test_get_vector_length():
assert len(list(x.shape[1:-2])) == 1 assert len(list(x.shape[1:-2])) == 1
assert len(list(x.shape[1:-1])) == 2 assert len(list(x.shape[1:-1])) == 2
z = join(0, as_tensor_variable(1, ndim=1), as_tensor_variable(x.shape[0], ndim=1))
assert isinstance(z.owner.op, Join)
assert get_vector_length(z) == 2
z = join(
0, as_tensor_variable([1, 2], ndim=1), as_tensor_variable(x.shape[0], ndim=1)
)
assert isinstance(z.owner.op, Join)
assert get_vector_length(z) == 3
empty_tuple = as_tensor_variable(()) empty_tuple = as_tensor_variable(())
assert 0 == get_vector_length(empty_tuple) assert 0 == get_vector_length(empty_tuple)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
import builtins import builtins
import logging import logging
import numbers
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
...@@ -4938,63 +4937,51 @@ def get_vector_length(v): ...@@ -4938,63 +4937,51 @@ def get_vector_length(v):
): ):
return get_vector_length(v.owner.inputs[0]) return get_vector_length(v.owner.inputs[0])
if v.owner and isinstance(v.owner.op, Join):
axis, *arrays = v.owner.inputs
try:
axis = get_scalar_constant_value(axis)
if axis != 0:
raise ValueError()
if not builtins.all(a.ndim == 1 for a in arrays):
raise ValueError()
return builtins.sum(get_vector_length(a) for a in arrays)
except (ValueError, NotScalarConstantError):
raise ValueError(f"Length of {v} cannot be determined")
# If we take a slice, we know how many elements it will result in # If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
if ( if (
v.owner v.owner
and isinstance(v.owner.op, theano.tensor.subtensor.Subtensor) and isinstance(v.owner.op, theano.tensor.subtensor.Subtensor)
and isinstance(v.owner.op.idx_list[0], slice) and isinstance(v.owner.op.idx_list[0], slice)
and v.owner.inputs[0].owner
and isinstance(v.owner.inputs[0].owner.op, theano.compile.ops.Shape)
): ):
start = extract_constant( try:
theano.tensor.subtensor.get_idx_list(v.owner.inputs, v.owner.op.idx_list)[ indices = theano.tensor.subtensor.get_idx_list(
0 v.owner.inputs, v.owner.op.idx_list
].start )
) start = (
stop = extract_constant( None
theano.tensor.subtensor.get_idx_list(v.owner.inputs, v.owner.op.idx_list)[ if indices[0].start is None
0 else get_scalar_constant_value(indices[0].start)
].stop )
) stop = (
step = extract_constant( None
theano.tensor.subtensor.get_idx_list(v.owner.inputs, v.owner.op.idx_list)[ if indices[0].stop is None
0 else get_scalar_constant_value(indices[0].stop)
].step )
) step = (
None
if indices[0].step is None
else get_scalar_constant_value(indices[0].step)
)
ndim = v.owner.inputs[0].owner.inputs[0].ndim arg_len = get_vector_length(v.owner.inputs[0])
types = (numbers.Integral, np.integer) return len(range(*slice(start, stop, step).indices(arg_len)))
if start is None: except (ValueError, NotScalarConstantError):
start = 0 raise ValueError(f"Length of {v} cannot be determined")
elif isinstance(start, types) and start < 0:
start += ndim
if start < 0:
start = 0
if stop is None:
stop = ndim
elif isinstance(stop, types):
if stop > ndim:
stop = ndim
elif stop < 0:
stop += ndim
if step is None:
step = 1
if ( raise ValueError(f"Length of {v} cannot be determined")
isinstance(stop, types)
and isinstance(start, types)
and isinstance(step, types)
and start >= 0
and stop >= 0
and step > 0
and stop >= start
):
return (stop - start - 1) // step + 1
if isinstance(v, Variable):
msg = theano.printing.debugprint(v, file="str")
else:
msg = str(v)
raise ValueError(f"length not known: {msg}")
@constructor @constructor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论