提交 6a16136d authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

allow ones/zeros to accept single integer shapes (plus tests)

上级 d86ac0d3
...@@ -2064,6 +2064,8 @@ def zeros(shape, dtype=None): ...@@ -2064,6 +2064,8 @@ def zeros(shape, dtype=None):
""" """
Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``. Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
""" """
if not isinstance(shape, (list, tuple)):
shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
return alloc(numpy.array(0, dtype=dtype), *shape) return alloc(numpy.array(0, dtype=dtype), *shape)
...@@ -2073,6 +2075,8 @@ def ones(shape, dtype=None): ...@@ -2073,6 +2075,8 @@ def ones(shape, dtype=None):
""" """
Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``. Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``.
""" """
if not isinstance(shape, (list, tuple)):
shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
return alloc(numpy.array(1, dtype=dtype), *shape) return alloc(numpy.array(1, dtype=dtype), *shape)
......
...@@ -1896,6 +1896,18 @@ class TestAlloc(unittest.TestCase): ...@@ -1896,6 +1896,18 @@ class TestAlloc(unittest.TestCase):
for node in topo]) == 1 for node in topo]) == 1
assert not isinstance(topo[0].op, DeepCopyOp) assert not isinstance(topo[0].op, DeepCopyOp)
def test_ones(self):
shapes = [[], 1, [1], [1, 2], [1, 2, 3]]
for shp in shapes:
ones = theano.function([], [tensor.ones(shp)])
assert numpy.allclose(ones(), numpy.ones(shp))
def test_zeros(self):
shapes = [[], 1, [1], [1, 2], [1, 2, 3]]
for shp in shapes:
zeros = theano.function([], [tensor.zeros(shp)])
assert numpy.allclose(zeros(), numpy.zeros(shp))
def test_eye(): def test_eye():
def check(dtype, N, M_=None, k=0): def check(dtype, N, M_=None, k=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论