提交 19ad27c9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement _get_vector_length for SpecifyShape

上级 1253e750
...@@ -498,6 +498,14 @@ class SpecifyShape(COp): ...@@ -498,6 +498,14 @@ class SpecifyShape(COp):
specify_shape = SpecifyShape() specify_shape = SpecifyShape()
@_get_vector_length.register(SpecifyShape)
def _get_vector_length_SpecifyShape(op, var):
try:
return aet.get_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
class Reshape(COp): class Reshape(COp):
"""Perform a reshape operation of the input x to the new shape shp. """Perform a reshape operation of the input x to the new shape shp.
The number of dimensions to which to reshape to (ndim) must be The number of dimensions to which to reshape to (ndim) must be
......
...@@ -960,11 +960,6 @@ def test_get_vector_length(): ...@@ -960,11 +960,6 @@ def test_get_vector_length():
res = get_vector_length(x) res = get_vector_length(x)
assert res == 4 assert res == 4
# Test `Shape`s
x = aesara.shared(np.zeros((2, 3, 4, 5)))
res = get_vector_length(x.shape)
assert res == 4
# Test `Subtensor`s # Test `Subtensor`s
x = as_tensor_variable(np.arange(4)) x = as_tensor_variable(np.arange(4))
assert get_vector_length(x[2:4]) == 2 assert get_vector_length(x[2:4]) == 2
......
...@@ -7,6 +7,7 @@ from aesara.compile.ops import DeepCopyOp ...@@ -7,6 +7,7 @@ from aesara.compile.ops import DeepCopyOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.tensor import get_vector_length
from aesara.tensor.basic import MakeVector, as_tensor_variable, constant from aesara.tensor.basic import MakeVector, as_tensor_variable, constant
from aesara.tensor.basic_opt import ShapeFeature from aesara.tensor.basic_opt import ShapeFeature
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
...@@ -509,3 +510,13 @@ def test_shape_i_basics(): ...@@ -509,3 +510,13 @@ def test_shape_i_basics():
with pytest.raises(TypeError): with pytest.raises(TypeError):
Shape_i(0)(scalar()) Shape_i(0)(scalar())
def test_get_vector_length():
# Test `Shape`s
x = aesara.shared(np.zeros((2, 3, 4, 5)))
assert get_vector_length(x.shape) == 4
# Test `SpecifyShape`
x = specify_shape(ivector(), (10,))
assert get_vector_length(x) == 10
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论