提交 c832967e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing get_scalar_constant_value tests

上级 5f681ce5
...@@ -3141,6 +3141,9 @@ def test_dimshuffle_duplicate(): ...@@ -3141,6 +3141,9 @@ def test_dimshuffle_duplicate():
class TestGetScalarConstantValue: class TestGetScalarConstantValue:
def test_basic(self): def test_basic(self):
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(aes.int64())
res = get_scalar_constant_value(aet.as_tensor(10)) res = get_scalar_constant_value(aet.as_tensor(10))
assert res == 10 assert res == 10
assert isinstance(res, np.ndarray) assert isinstance(res, np.ndarray)
...@@ -3190,6 +3193,11 @@ class TestGetScalarConstantValue: ...@@ -3190,6 +3193,11 @@ class TestGetScalarConstantValue:
assert isinstance(res, np.ndarray) assert isinstance(res, np.ndarray)
assert 10 == res assert 10 == res
@pytest.mark.xfail(reason="Incomplete implementation")
def test_DimShufle(self):
a = as_tensor_variable(1.0)[None][0]
assert get_scalar_constant_value(a) == 1
def test_subtensor_of_constant(self): def test_subtensor_of_constant(self):
c = constant(random(5)) c = constant(random(5))
for i in range(c.value.shape[0]): for i in range(c.value.shape[0]):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论