Unverified 提交 c9f5f656 authored 作者: Will Dean's avatar Will Dean 提交者: GitHub

Support int-like shapes in `pt.full` (#759)

上级 146a0a84
...@@ -1719,6 +1719,9 @@ def full(shape, fill_value, dtype=None): ...@@ -1719,6 +1719,9 @@ def full(shape, fill_value, dtype=None):
fill_value = as_tensor_variable(fill_value) fill_value = as_tensor_variable(fill_value)
if dtype: if dtype:
fill_value = fill_value.astype(dtype) fill_value = fill_value.astype(dtype)
if np.ndim(shape) == 0:
shape = (shape,)
return alloc(fill_value, *shape) return alloc(fill_value, *shape)
......
...@@ -848,10 +848,15 @@ class TestAlloc: ...@@ -848,10 +848,15 @@ class TestAlloc:
inp = np.zeros(shp, dtype=config.floatX) inp = np.zeros(shp, dtype=config.floatX)
assert np.allclose(zeros_tensor(inp), np.zeros(shp)) assert np.allclose(zeros_tensor(inp), np.zeros(shp))
def test_full(self): @pytest.mark.parametrize(
full_pt = ptb.full((2, 3), 3, dtype="int64") "shape", [(2, 3), 5, np.int32(5), np.array(5), constant(5)]
)
def test_full(self, shape):
full_pt = ptb.full(shape, 3, dtype="int64")
res = pytensor.function([], full_pt, mode=self.mode)() res = pytensor.function([], full_pt, mode=self.mode)()
assert np.array_equal(res, np.full((2, 3), 3, dtype="int64")) if isinstance(shape, ptb.TensorVariable):
shape = shape.eval()
assert np.array_equal(res, np.full(shape, 3, dtype="int64"))
@pytest.mark.parametrize("func", (ptb.zeros, ptb.empty)) @pytest.mark.parametrize("func", (ptb.zeros, ptb.empty))
def test_rebuild(self, func): def test_rebuild(self, func):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论