提交 b83d0a64 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Make get_vector_length handle simple Join Ops

上级 3a2556a8
......@@ -2387,6 +2387,16 @@ def test_get_vector_length():
assert len(list(x.shape[1:-2])) == 1
assert len(list(x.shape[1:-1])) == 2
z = join(0, as_tensor_variable(1, ndim=1), as_tensor_variable(x.shape[0], ndim=1))
assert isinstance(z.owner.op, Join)
assert get_vector_length(z) == 2
z = join(
0, as_tensor_variable([1, 2], ndim=1), as_tensor_variable(x.shape[0], ndim=1)
)
assert isinstance(z.owner.op, Join)
assert get_vector_length(z) == 3
empty_tuple = as_tensor_variable(())
assert 0 == get_vector_length(empty_tuple)
......
......@@ -2,7 +2,6 @@
import builtins
import logging
import numbers
import warnings
from collections.abc import Sequence
from functools import partial
......@@ -4938,63 +4937,51 @@ def get_vector_length(v):
):
return get_vector_length(v.owner.inputs[0])
if v.owner and isinstance(v.owner.op, Join):
axis, *arrays = v.owner.inputs
try:
axis = get_scalar_constant_value(axis)
if axis != 0:
raise ValueError()
if not builtins.all(a.ndim == 1 for a in arrays):
raise ValueError()
return builtins.sum(get_vector_length(a) for a in arrays)
except (ValueError, NotScalarConstantError):
raise ValueError(f"Length of {v} cannot be determined")
# If we take a slice, we know how many elements it will result in
# TODO: We can cover more `*Subtensor` cases.
if (
v.owner
and isinstance(v.owner.op, theano.tensor.subtensor.Subtensor)
and isinstance(v.owner.op.idx_list[0], slice)
and v.owner.inputs[0].owner
and isinstance(v.owner.inputs[0].owner.op, theano.compile.ops.Shape)
):
start = extract_constant(
theano.tensor.subtensor.get_idx_list(v.owner.inputs, v.owner.op.idx_list)[
0
].start
)
stop = extract_constant(
theano.tensor.subtensor.get_idx_list(v.owner.inputs, v.owner.op.idx_list)[
0
].stop
)
step = extract_constant(
theano.tensor.subtensor.get_idx_list(v.owner.inputs, v.owner.op.idx_list)[
0
].step
)
try:
indices = theano.tensor.subtensor.get_idx_list(
v.owner.inputs, v.owner.op.idx_list
)
start = (
None
if indices[0].start is None
else get_scalar_constant_value(indices[0].start)
)
stop = (
None
if indices[0].stop is None
else get_scalar_constant_value(indices[0].stop)
)
step = (
None
if indices[0].step is None
else get_scalar_constant_value(indices[0].step)
)
ndim = v.owner.inputs[0].owner.inputs[0].ndim
types = (numbers.Integral, np.integer)
if start is None:
start = 0
elif isinstance(start, types) and start < 0:
start += ndim
if start < 0:
start = 0
if stop is None:
stop = ndim
elif isinstance(stop, types):
if stop > ndim:
stop = ndim
elif stop < 0:
stop += ndim
if step is None:
step = 1
arg_len = get_vector_length(v.owner.inputs[0])
return len(range(*slice(start, stop, step).indices(arg_len)))
except (ValueError, NotScalarConstantError):
raise ValueError(f"Length of {v} cannot be determined")
if (
isinstance(stop, types)
and isinstance(start, types)
and isinstance(step, types)
and start >= 0
and stop >= 0
and step > 0
and stop >= start
):
return (stop - start - 1) // step + 1
if isinstance(v, Variable):
msg = theano.printing.debugprint(v, file="str")
else:
msg = str(v)
raise ValueError(f"length not known: {msg}")
raise ValueError(f"Length of {v} cannot be determined")
@constructor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论