提交 f1dc0897 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make shape_tuple return all static shape information

上级 456cce1d
...@@ -8,6 +8,7 @@ import numpy as np ...@@ -8,6 +8,7 @@ import numpy as np
import aesara import aesara
from aesara.gradient import DisconnectedType from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable from aesara.graph.basic import Apply, Variable
from aesara.graph.type import HasShape
from aesara.link.c.op import COp from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
...@@ -158,18 +159,28 @@ def _get_vector_length_Shape(op, var): ...@@ -158,18 +159,28 @@ def _get_vector_length_Shape(op, var):
def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]: def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]:
"""Get a tuple of symbolic shape values. r"""Get a tuple of symbolic shape values.
This will return `ScalarConstant`\s for static shape values.
This will return a `ScalarConstant` with the value ``1`` wherever
broadcastable is ``True``.
""" """
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1) if not isinstance(x.type, HasShape):
return tuple( # We assume/call it a scalar
one_at if getattr(sh, "value", sh) == 1 or bcast else sh return ()
for sh, bcast in zip(
shape(x), getattr(x, "broadcastable", (False,) * x.type.ndim) res = ()
) symbolic_shape = shape(x)
) static_shape = x.type.shape
for i in range(x.type.ndim):
shape_val = static_shape[i]
if shape_val is not None:
# TODO: Why not use uint64?
res += (aesara.scalar.ScalarConstant(aesara.scalar.int64, shape_val),)
else:
res += (symbolic_shape[i],)
return res
class Shape_i(COp): class Shape_i(COp):
......
...@@ -9,6 +9,7 @@ from aesara.graph.basic import Variable ...@@ -9,6 +9,7 @@ from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.scalar.basic import ScalarConstant
from aesara.tensor import as_tensor_variable, get_vector_length, row from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
...@@ -22,6 +23,7 @@ from aesara.tensor.shape import ( ...@@ -22,6 +23,7 @@ from aesara.tensor.shape import (
reshape, reshape,
shape, shape,
shape_i, shape_i,
shape_tuple,
specify_broadcastable, specify_broadcastable,
specify_shape, specify_shape,
unbroadcast, unbroadcast,
...@@ -46,6 +48,7 @@ from aesara.tensor.type_other import NoneConst ...@@ -46,6 +48,7 @@ from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable from aesara.tensor.var import TensorVariable
from aesara.typed_list import make_list from aesara.typed_list import make_list
from tests import unittest_tools as utt from tests import unittest_tools as utt
from tests.graph.utils import MyType2
from tests.tensor.utils import eval_outputs, random from tests.tensor.utils import eval_outputs, random
from tests.test_rop import RopLopChecker from tests.test_rop import RopLopChecker
...@@ -657,3 +660,18 @@ class TestUnbroadcastInferShape(utt.InferShapeTester): ...@@ -657,3 +660,18 @@ class TestUnbroadcastInferShape(utt.InferShapeTester):
Unbroadcast, Unbroadcast,
warn=False, warn=False,
) )
def test_shape_tuple():
x = Variable(MyType2(), None, None)
assert shape_tuple(x) == ()
x = tensor(np.float64, shape=(1, 2, None))
res = shape_tuple(x)
assert isinstance(res, tuple)
assert isinstance(res[0], ScalarConstant)
assert res[0].data == 1
assert isinstance(res[1], ScalarConstant)
assert res[1].data == 2
assert not isinstance(res[2], ScalarConstant)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论