提交 91e20cc4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Generalize the type interface to tensor.zeros and tensor.ones

上级 746eeac4
......@@ -2877,7 +2877,7 @@ class TestAlloc:
assert not isinstance(topo[0].op, DeepCopyOp)
def test_ones(self):
for shp in [[], 1, [1], [1, 2], [1, 2, 3]]:
for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
ones = theano.function([], [tensor.ones(shp)], mode=self.mode)
assert np.allclose(ones(), np.ones(shp))
......@@ -2894,7 +2894,7 @@ class TestAlloc:
assert np.allclose(ones_tensor(inp), np.ones(shp))
def test_zeros(self):
for shp in [[], 1, [1], [1, 2], [1, 2, 3]]:
for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
zeros = theano.function([], [tensor.zeros(shp)], mode=self.mode)
assert np.allclose(zeros(), np.zeros(shp))
......
......@@ -12,6 +12,7 @@ import theano
import theano.scalar.sharedvar
from functools import partial
from collections.abc import Sequence
from six import integer_types
......@@ -2683,7 +2684,7 @@ def zeros(shape, dtype=None):
"""
Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
"""
if not isinstance(shape, (list, tuple, TensorVariable)):
if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
shape = [shape]
if dtype is None:
dtype = config.floatX
......@@ -2694,7 +2695,7 @@ def ones(shape, dtype=None):
"""
Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``.
"""
if not isinstance(shape, (list, tuple, TensorVariable)):
if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
shape = [shape]
if dtype is None:
dtype = config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论