提交 3635eacd authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Allow get_scalar_constant_value to get shape values from constants

上级 46adff86
...@@ -5847,6 +5847,19 @@ class TestGetScalarConstantValue: ...@@ -5847,6 +5847,19 @@ class TestGetScalarConstantValue:
v = tt.row() v = tt.row()
assert get_scalar_constant_value(v.shape[0]) == 1 assert get_scalar_constant_value(v.shape[0]) == 1
res = tt.get_scalar_constant_value(tt.as_tensor([10, 20]).shape[0])
assert isinstance(res, np.ndarray)
assert 2 == res
res = tt.get_scalar_constant_value(
9 + tt.as_tensor([1.0]).shape[0],
elemwise=True,
only_process_constants=False,
max_recur=9,
)
assert isinstance(res, np.ndarray)
assert 10 == res
def test_subtensor_of_constant(self): def test_subtensor_of_constant(self):
c = constant(rand(5)) c = constant(rand(5))
for i in range(c.value.shape[0]): for i in range(c.value.shape[0]):
......
...@@ -659,6 +659,9 @@ def get_scalar_constant_value( ...@@ -659,6 +659,9 @@ def get_scalar_constant_value(
if gp_broadcastable[idx]: if gp_broadcastable[idx]:
return np.asarray(1) return np.asarray(1)
if isinstance(grandparent, Constant):
return np.asarray(grandparent.data.shape[idx])
raise NotScalarConstantError(v) raise NotScalarConstantError(v)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论