提交 91e20cc4 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Generalize the type interface to tensor.zeros and tensor.ones

上级 746eeac4
...@@ -2877,7 +2877,7 @@ class TestAlloc: ...@@ -2877,7 +2877,7 @@ class TestAlloc:
assert not isinstance(topo[0].op, DeepCopyOp) assert not isinstance(topo[0].op, DeepCopyOp)
def test_ones(self): def test_ones(self):
for shp in [[], 1, [1], [1, 2], [1, 2, 3]]: for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
ones = theano.function([], [tensor.ones(shp)], mode=self.mode) ones = theano.function([], [tensor.ones(shp)], mode=self.mode)
assert np.allclose(ones(), np.ones(shp)) assert np.allclose(ones(), np.ones(shp))
...@@ -2894,7 +2894,7 @@ class TestAlloc: ...@@ -2894,7 +2894,7 @@ class TestAlloc:
assert np.allclose(ones_tensor(inp), np.ones(shp)) assert np.allclose(ones_tensor(inp), np.ones(shp))
def test_zeros(self): def test_zeros(self):
for shp in [[], 1, [1], [1, 2], [1, 2, 3]]: for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
zeros = theano.function([], [tensor.zeros(shp)], mode=self.mode) zeros = theano.function([], [tensor.zeros(shp)], mode=self.mode)
assert np.allclose(zeros(), np.zeros(shp)) assert np.allclose(zeros(), np.zeros(shp))
......
...@@ -12,6 +12,7 @@ import theano ...@@ -12,6 +12,7 @@ import theano
import theano.scalar.sharedvar import theano.scalar.sharedvar
from functools import partial from functools import partial
from collections.abc import Sequence
from six import integer_types from six import integer_types
...@@ -2683,7 +2684,7 @@ def zeros(shape, dtype=None): ...@@ -2683,7 +2684,7 @@ def zeros(shape, dtype=None):
""" """
Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``. Create a Tensor filled with zeros, closer to Numpy's syntax than ``alloc``.
""" """
if not isinstance(shape, (list, tuple, TensorVariable)): if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
shape = [shape] shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
...@@ -2694,7 +2695,7 @@ def ones(shape, dtype=None): ...@@ -2694,7 +2695,7 @@ def ones(shape, dtype=None):
""" """
Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``. Create a Tensor filled with ones, closer to Numpy's syntax than ``alloc``.
""" """
if not isinstance(shape, (list, tuple, TensorVariable)): if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
shape = [shape] shape = [shape]
if dtype is None: if dtype is None:
dtype = config.floatX dtype = config.floatX
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论