Commit 511b593d — authored by Brandon T. Willard, committed by Thomas Wiecki

Move theano.sparse_grad to theano.sparse.sparse_grad

上级 a8c334aa
...@@ -59,5 +59,3 @@ There are also some top-level imports that you might find more convenient: ...@@ -59,5 +59,3 @@ There are also some top-level imports that you might find more convenient:
Works like :func:`tensor.dot` for both sparse and dense matrix products Works like :func:`tensor.dot` for both sparse and dense matrix products
.. autofunction:: theano.clone_replace .. autofunction:: theano.clone_replace
.. autofunction:: theano.sparse_grad
...@@ -297,4 +297,4 @@ List of Implemented Operations ...@@ -297,4 +297,4 @@ List of Implemented Operations
.. automodule:: theano.sparse.basic .. automodule:: theano.sparse.basic
:members: :members:
.. autofunction:: tests.sparse.test_basic.sparse_random_inputs .. autofunction:: theano.sparse.sparse_grad
...@@ -656,7 +656,7 @@ class TestConstructSparseFromList: ...@@ -656,7 +656,7 @@ class TestConstructSparseFromList:
# USER INTERFACE # USER INTERFACE
m = matrix() m = matrix()
v = ivector() v = ivector()
sub = theano.sparse_grad(m[v]) sub = theano.sparse.sparse_grad(m[v])
g = theano.grad(sub.sum(), m) g = theano.grad(sub.sum(), m)
assert isinstance(g.owner.op, ConstructSparseFromList) assert isinstance(g.owner.op, ConstructSparseFromList)
...@@ -675,7 +675,7 @@ class TestConstructSparseFromList: ...@@ -675,7 +675,7 @@ class TestConstructSparseFromList:
shared_v = theano.shared(valv) shared_v = theano.shared(valv)
def fn(m): def fn(m):
return theano.sparse_grad(m[shared_v]) return theano.sparse.sparse_grad(m[shared_v])
verify_grad_sparse(fn, [valm]) verify_grad_sparse(fn, [valm])
...@@ -691,7 +691,7 @@ class TestConstructSparseFromList: ...@@ -691,7 +691,7 @@ class TestConstructSparseFromList:
# Test that we raise an error, as we can't create a sparse # Test that we raise an error, as we can't create a sparse
# grad from tensors that don't have 2 dimensions. # grad from tensors that don't have 2 dimensions.
sub = theano.sparse_grad(sub) sub = theano.sparse.sparse_grad(sub)
with pytest.raises(TypeError): with pytest.raises(TypeError):
theano.grad(sub.sum(), t) theano.grad(sub.sum(), t)
......
...@@ -129,23 +129,6 @@ def get_scalar_constant_value(v): ...@@ -129,23 +129,6 @@ def get_scalar_constant_value(v):
return tensor.get_scalar_constant_value(v) return tensor.get_scalar_constant_value(v)
def sparse_grad(var):
    """Return a new variable whose gradient will be stored in a sparse
    format instead of dense.

    Only variables created by ``AdvancedSubtensor1`` are supported,
    i.e. ``a_tensor_var[an_int_vector]``.

    .. versionadded:: 0.6rc4
    """
    # Imported locally to avoid a circular import at module load time.
    from theano.tensor.subtensor import AdvancedSubtensor1

    owner = var.owner
    assert isinstance(owner.op, AdvancedSubtensor1)
    # Re-apply the same subtensor op to the same inputs, but flagged so
    # that its gradient is constructed as a sparse matrix.
    sparse_op = owner.op.__class__(sparse_grad=True)
    return sparse_op(*owner.inputs)
import theano.tensor.random.var import theano.tensor.random.var
from theano.graph.basic import clone_replace from theano.graph.basic import clone_replace
from theano.scan import checkpoints from theano.scan import checkpoints
......
...@@ -9,10 +9,26 @@ except ImportError: ...@@ -9,10 +9,26 @@ except ImportError:
enable_sparse = False enable_sparse = False
warn("SciPy can't be imported. Sparse matrix support is disabled.") warn("SciPy can't be imported. Sparse matrix support is disabled.")
from theano.sparse.type import * from theano.sparse.type import SparseType, _is_sparse
if enable_sparse: if enable_sparse:
from theano.sparse import opt, sharedvar from theano.sparse import opt, sharedvar
from theano.sparse.basic import * from theano.sparse.basic import *
from theano.sparse.sharedvar import sparse_constructor as shared from theano.sparse.sharedvar import sparse_constructor as shared
def sparse_grad(var):
    """Return a new variable whose gradient will be stored in a sparse
    format instead of dense.

    Currently only variables created by ``AdvancedSubtensor1`` are
    supported, i.e. ``a_tensor_var[an_int_vector]``.

    Parameters
    ----------
    var : TensorVariable
        The result of an ``AdvancedSubtensor1`` indexing operation.

    Returns
    -------
    TensorVariable
        A variable equivalent to ``var`` whose gradient will be built
        as a sparse matrix.

    Raises
    ------
    TypeError
        If ``var`` was not produced by ``AdvancedSubtensor1``.

    .. versionadded:: 0.6rc4
    """
    # Imported locally to avoid a circular import at module load time.
    from theano.tensor.subtensor import AdvancedSubtensor1

    # A bare `assert` is stripped under `python -O` and raises an opaque
    # AttributeError when `var.owner` is None (a root variable); validate
    # explicitly with a descriptive error instead.
    if var.owner is None or not isinstance(var.owner.op, AdvancedSubtensor1):
        raise TypeError(
            "sparse_grad() only supports variables created by "
            "AdvancedSubtensor1, e.g. a_tensor_var[an_int_vector]."
        )

    # Re-apply the same subtensor op to the same inputs, but flagged so
    # that its gradient is constructed as a sparse matrix.
    return var.owner.op.__class__(sparse_grad=True)(*var.owner.inputs)
...@@ -71,7 +71,7 @@ class TensorType(CType): ...@@ -71,7 +71,7 @@ class TensorType(CType):
warnings.warn( warnings.warn(
"You use an old interface to" "You use an old interface to"
" AdvancedSubtensor1 sparse_grad. Now use" " AdvancedSubtensor1 sparse_grad. Now use"
" theano.sparse_grad(a_tensor[an_int_vector]).", " theano.sparse.sparse_grad(a_tensor[an_int_vector]).",
category=DeprecationWarning, category=DeprecationWarning,
) )
......
Markdown 格式
0%
You are adding 0 people to this discussion. Please proceed with caution.
请先完成此评论的编辑!
Register or sign in to post a comment