提交 f1dc0897 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make shape_tuple return all static shape information

上级 456cce1d
......@@ -8,6 +8,7 @@ import numpy as np
import aesara
from aesara.gradient import DisconnectedType
from aesara.graph.basic import Apply, Variable
from aesara.graph.type import HasShape
from aesara.link.c.op import COp
from aesara.link.c.params_type import ParamsType
from aesara.misc.safe_asarray import _asarray
......@@ -158,18 +159,28 @@ def _get_vector_length_Shape(op, var):
def shape_tuple(x: TensorVariable) -> Tuple[Variable, ...]:
"""Get a tuple of symbolic shape values.
r"""Get a tuple of symbolic shape values.
This will return `ScalarConstant`\s for static shape values.
This will return a `ScalarConstant` with the value ``1`` wherever
broadcastable is ``True``.
"""
one_at = aesara.scalar.ScalarConstant(aesara.scalar.int64, 1)
return tuple(
one_at if getattr(sh, "value", sh) == 1 or bcast else sh
for sh, bcast in zip(
shape(x), getattr(x, "broadcastable", (False,) * x.type.ndim)
)
)
if not isinstance(x.type, HasShape):
# We assume/call it a scalar
return ()
res = ()
symbolic_shape = shape(x)
static_shape = x.type.shape
for i in range(x.type.ndim):
shape_val = static_shape[i]
if shape_val is not None:
# TODO: Why not use uint64?
res += (aesara.scalar.ScalarConstant(aesara.scalar.int64, shape_val),)
else:
res += (symbolic_shape[i],)
return res
class Shape_i(COp):
......
......@@ -9,6 +9,7 @@ from aesara.graph.basic import Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type
from aesara.misc.safe_asarray import _asarray
from aesara.scalar.basic import ScalarConstant
from aesara.tensor import as_tensor_variable, get_vector_length, row
from aesara.tensor.basic import MakeVector, constant
from aesara.tensor.elemwise import DimShuffle, Elemwise
......@@ -22,6 +23,7 @@ from aesara.tensor.shape import (
reshape,
shape,
shape_i,
shape_tuple,
specify_broadcastable,
specify_shape,
unbroadcast,
......@@ -46,6 +48,7 @@ from aesara.tensor.type_other import NoneConst
from aesara.tensor.var import TensorVariable
from aesara.typed_list import make_list
from tests import unittest_tools as utt
from tests.graph.utils import MyType2
from tests.tensor.utils import eval_outputs, random
from tests.test_rop import RopLopChecker
......@@ -657,3 +660,18 @@ class TestUnbroadcastInferShape(utt.InferShapeTester):
Unbroadcast,
warn=False,
)
def test_shape_tuple():
x = Variable(MyType2(), None, None)
assert shape_tuple(x) == ()
x = tensor(np.float64, shape=(1, 2, None))
res = shape_tuple(x)
assert isinstance(res, tuple)
assert isinstance(res[0], ScalarConstant)
assert res[0].data == 1
assert isinstance(res[1], ScalarConstant)
assert res[1].data == 2
assert not isinstance(res[2], ScalarConstant)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论