提交 0e018bc4 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron 提交者: Olivier Delalleau

Add TensorVariable.zeros_like() and SparseVariable.zeros_like()

Code for sparse was stolen from Yann Dauphin.
上级 b9977855
...@@ -137,6 +137,11 @@ def sp_ones_like(x): ...@@ -137,6 +137,11 @@ def sp_ones_like(x):
data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats data, indices, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats
return CSM(format=x.format)(tensor.ones_like(data), indices, indptr, shape) return CSM(format=x.format)(tensor.ones_like(data), indices, indptr, shape)
def sp_zeros_like(x):
_, _, indptr, shape = csm_properties(x) #TODO: don't restrict to CSM formats
return CSM(format=x.format)(numpy.array([], dtype=x.type.dtype), numpy.array([]), tensor.zeros_like(indptr), shape)
class _sparse_py_operators: class _sparse_py_operators:
T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)") T = property(lambda self: transpose(self), doc = "Return aliased transpose of self (read-only)")
def __neg__(self): return neg(self) def __neg__(self): return neg(self)
...@@ -184,6 +189,10 @@ class SparseVariable(gof.Variable, _sparse_py_operators): ...@@ -184,6 +189,10 @@ class SparseVariable(gof.Variable, _sparse_py_operators):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
def zeros_like(model, dtype=None):
# TODO: don't ignore dtype
return sp_zeros_like(model)
class SparseConstantSignature(tuple): class SparseConstantSignature(tuple):
def __eq__(self, other): def __eq__(self, other):
(a, b), (x,y) = self, other (a, b), (x,y) = self, other
......
...@@ -820,6 +820,17 @@ class UsmmTests(unittest.TestCase): ...@@ -820,6 +820,17 @@ class UsmmTests(unittest.TestCase):
for node in topo]) == nb for node in topo]) == nb
class test_zeros_like(unittest.TestCase):
def test(self):
x = theano.sparse.csr_matrix()
f = theano.function([x], theano.sparse.sp_zeros_like(x))
vx = scipy.sparse.csr_matrix(numpy.asarray(numpy.random.binomial(1, 0.5, (100, 100)), dtype=theano.config.floatX))
fx = f(vx)
assert fx.nnz == 0
assert fx.shape == vx.shape
def test_shape_i(): def test_shape_i():
sparse_dtype = 'float32' sparse_dtype = 'float32'
......
...@@ -1454,7 +1454,12 @@ class _tensor_py_operators: ...@@ -1454,7 +1454,12 @@ class _tensor_py_operators:
class TensorVariable(_tensor_py_operators, Variable): class TensorVariable(_tensor_py_operators, Variable):
"""Subclass to add the tensor operators to the basic `Variable` class.""" """Subclass to add the tensor operators to the basic `Variable` class."""
def zeros_like(model, dtype=None):
"Used for grad, Lop and Rop"
# Tested through the zeros_like method below
if dtype is None:
dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype))
TensorType.Variable = TensorVariable TensorType.Variable = TensorVariable
...@@ -2364,10 +2369,7 @@ def ones_like(model, dtype=None): ...@@ -2364,10 +2369,7 @@ def ones_like(model, dtype=None):
@constructor @constructor
def zeros_like(model, dtype=None): def zeros_like(model, dtype=None):
"""equivalent of numpy.zeros_like""" """equivalent of numpy.zeros_like"""
if dtype is None: return TensorVariable.zeros_like(model, dtype=None)
dtype = model.type.dtype
return fill(model, constant(0.0, dtype=dtype))
def zeros(shape, dtype=config.floatX): def zeros(shape, dtype=config.floatX):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论