提交 6a16136d authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

allow ones/zeros to accept single integer shapes (plus tests)

上级 d86ac0d3
......@@ -2064,6 +2064,8 @@ def zeros(shape, dtype=None):
"""
Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
"""
if not isinstance(shape, (list, tuple)):
shape = [shape]
if dtype is None:
dtype = config.floatX
return alloc(numpy.array(0, dtype=dtype), *shape)
......@@ -2073,6 +2075,8 @@ def ones(shape, dtype=None):
"""
Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``.
"""
if not isinstance(shape, (list, tuple)):
shape = [shape]
if dtype is None:
dtype = config.floatX
return alloc(numpy.array(1, dtype=dtype), *shape)
......
......@@ -1896,6 +1896,18 @@ class TestAlloc(unittest.TestCase):
for node in topo]) == 1
assert not isinstance(topo[0].op, DeepCopyOp)
def test_ones(self):
shapes = [[], 1, [1], [1, 2], [1, 2, 3]]
for shp in shapes:
ones = theano.function([], [tensor.ones(shp)])
assert numpy.allclose(ones(), numpy.ones(shp))
def test_zeros(self):
shapes = [[], 1, [1], [1, 2], [1, 2, 3]]
for shp in shapes:
zeros = theano.function([], [tensor.zeros(shp)])
assert numpy.allclose(zeros(), numpy.zeros(shp))
def test_eye():
def check(dtype, N, M_=None, k=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论