提交 c39ddb2e authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

need to also exclude tensorvariable arguments (and tests)

上级 6a16136d
......@@ -2064,7 +2064,7 @@ def zeros(shape, dtype=None):
"""
Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
"""
if not isinstance(shape, (list, tuple)):
if not isinstance(shape, (list, tuple, TensorVariable)):
shape = [shape]
if dtype is None:
dtype = config.floatX
......@@ -2075,7 +2075,7 @@ def ones(shape, dtype=None):
"""
Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``.
"""
if not isinstance(shape, (list, tuple)):
if not isinstance(shape, (list, tuple, TensorVariable)):
shape = [shape]
if dtype is None:
dtype = config.floatX
......
......@@ -1897,17 +1897,39 @@ class TestAlloc(unittest.TestCase):
assert not isinstance(topo[0].op, DeepCopyOp)
def test_ones(self):
shapes = [[], 1, [1], [1, 2], [1, 2, 3]]
for shp in shapes:
for shp in [[], 1, [1], [1, 2], [1, 2, 3]]:
ones = theano.function([], [tensor.ones(shp)])
assert numpy.allclose(ones(), numpy.ones(shp))
# scalar doesn't have to be provided as input
x = scalar()
shp = []
ones_scalar = theano.function([], [tensor.ones(x.shape)])
assert numpy.allclose(ones_scalar(), numpy.ones(shp))
for (typ, shp) in [(vector, [3]), (matrix, [3,4])]:
x = typ()
ones_tensor = theano.function([x], [tensor.ones(x.shape)])
assert numpy.allclose(ones_tensor(numpy.zeros(shp)),
numpy.ones(shp))
def test_zeros(self):
shapes = [[], 1, [1], [1, 2], [1, 2, 3]]
for shp in shapes:
for shp in [[], 1, [1], [1, 2], [1, 2, 3]]:
zeros = theano.function([], [tensor.zeros(shp)])
assert numpy.allclose(zeros(), numpy.zeros(shp))
# scalar doesn't have to be provided as input
x = scalar()
shp = []
zeros_scalar = theano.function([], [tensor.zeros(x.shape)])
assert numpy.allclose(zeros_scalar(), numpy.zeros(shp))
for (typ, shp) in [(vector, [3]), (matrix, [3,4])]:
x = typ()
zeros_tensor = theano.function([x], [tensor.zeros(x.shape)])
assert numpy.allclose(zeros_tensor(numpy.zeros(shp)),
numpy.zeros(shp))
def test_eye():
def check(dtype, N, M_=None, k=0):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论