提交 6ef1452a authored 作者: danhphan's avatar danhphan 提交者: Ricardo Vieira

Fix tensor.zeros and tensor.ones with symbolic scalar

上级 5a0fb0e7
......@@ -978,7 +978,10 @@ def zeros_like(model, dtype=None, opt=False):
def zeros(shape, dtype=None):
"""Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
if not (
isinstance(shape, (np.ndarray, Sequence))
or (isinstance(shape, TensorVariable) and shape.ndim > 0)
):
shape = [shape]
if dtype is None:
dtype = config.floatX
......@@ -987,7 +990,10 @@ def zeros(shape, dtype=None):
def ones(shape, dtype=None):
"""Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
if not (
isinstance(shape, (np.ndarray, Sequence))
or (isinstance(shape, TensorVariable) and shape.ndim > 0)
):
shape = [shape]
if dtype is None:
dtype = config.floatX
......@@ -4274,6 +4280,11 @@ def empty(shape, dtype=None):
Desired output data-type for the array, e.g, `numpy.int8`. Default is
`numpy.float64`.
"""
if not (
isinstance(shape, (np.ndarray, Sequence))
or (isinstance(shape, TensorVariable) and shape.ndim > 0)
):
shape = [shape]
if dtype is None:
dtype = config.floatX
return AllocEmpty(dtype)(*shape)
......
......@@ -754,6 +754,11 @@ class TestAlloc:
for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
ones = aesara.function([], [at.ones(shp)], mode=self.mode)
assert np.allclose(ones(), np.ones(shp))
# When shape is a TensorConstant
ones_const = aesara.function(
[], [at.ones(at.constant(shp))], mode=self.mode
)
assert np.allclose(ones_const(), np.ones(shp))
# scalar doesn't have to be provided as input
x = scalar()
......@@ -771,6 +776,11 @@ class TestAlloc:
for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
zeros = aesara.function([], [at.zeros(shp)], mode=self.mode)
assert np.allclose(zeros(), np.zeros(shp))
# When shape is a TensorConstant
zeros_const = aesara.function(
[], [at.zeros(at.constant(shp))], mode=self.mode
)
assert np.allclose(zeros_const(), np.zeros(shp))
# scalar doesn't have to be provided as input
x = scalar()
......@@ -4381,6 +4391,10 @@ def test_empty():
assert out.shape == (2, 3)
assert out.dtype == "float32"
empty_at = at.empty(3)
res = aesara.function([], empty_at)()
assert res.shape == (3,)
empty_at = at.empty((2, 3), dtype=None)
res = aesara.function([], empty_at)()
assert res.shape == (2, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论