提交 c8ed209a authored 作者: Frederic's avatar Frederic

New internal interface to the sparse grad of AdvancedSubtensor1

上级 9345964d
...@@ -176,10 +176,8 @@ def sparse_grad(var): ...@@ -176,10 +176,8 @@ def sparse_grad(var):
""" """
assert isinstance(var.owner.op, tensor.AdvancedSubtensor1) assert isinstance(var.owner.op, tensor.AdvancedSubtensor1)
# TODO change the internal representation!!! ret = var.owner.op.__class__(sparse_grad=True)(*var.owner.inputs)
# It work, but bad as out.type is shared with var.type!!! return ret
var.owner.inputs[0].tag.sparse_grad = True
return var
import theano.tests import theano.tests
......
...@@ -450,6 +450,14 @@ class TestConstructSparseFromList(unittest.TestCase): ...@@ -450,6 +450,14 @@ class TestConstructSparseFromList(unittest.TestCase):
g = theano.grad(sub.sum(), m) g = theano.grad(sub.sum(), m)
assert isinstance(g.owner.op, ConstructSparseFromList) assert isinstance(g.owner.op, ConstructSparseFromList)
# Test that we create a sparse grad when asked
# Op INTERFACE
m = theano.tensor.matrix()
v = theano.tensor.ivector()
sub = theano.tensor.AdvancedSubtensor1(sparse_grad=True)(m, v)
g = theano.grad(sub.sum(), m)
assert isinstance(g.owner.op, ConstructSparseFromList)
# Test the sparse grad # Test the sparse grad
valm = numpy.random.rand(5, 4).astype(config.floatX) valm = numpy.random.rand(5, 4).astype(config.floatX)
valv = numpy.random.random_integers(0, 4, 10) valv = numpy.random.random_integers(0, 4, 10)
......
...@@ -706,6 +706,11 @@ class TensorType(Type): ...@@ -706,6 +706,11 @@ class TensorType(Type):
self.name = name self.name = name
self.numpy_dtype = numpy.dtype(self.dtype) self.numpy_dtype = numpy.dtype(self.dtype)
self.sparse_grad = sparse_grad self.sparse_grad = sparse_grad
if sparse_grad:
warnings.warn(
"DEPRECATION WARNING: You use an old interface to"
" AdvancedSubtensor1 sparse_grad. Now use"
" theano.sparse_grad(a_tensor[an_int_vector]).")
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a """Convert `data` to something which can be associated to a
...@@ -7153,6 +7158,9 @@ def inverse_permutation(perm): ...@@ -7153,6 +7158,9 @@ def inverse_permutation(perm):
class AdvancedSubtensor1(Op): class AdvancedSubtensor1(Op):
"""Implement x[ilist] where ilist is a vector of integers.""" """Implement x[ilist] where ilist is a vector of integers."""
def __init__(self, sparse_grad=False):
self.sparse_grad = sparse_grad
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self))
...@@ -7212,8 +7220,14 @@ class AdvancedSubtensor1(Op): ...@@ -7212,8 +7220,14 @@ class AdvancedSubtensor1(Op):
x, ilist = inputs x, ilist = inputs
gz, = grads gz, = grads
assert len(inputs) == 2 assert len(inputs) == 2
sparse = False
if x.type.sparse_grad: if getattr(x.type, 'sparse_grad', False):
sparse = True
warnings.warn(
"DEPRECATION WARNING: AdvancedSubtensor1, you are using"
" an old interface to the sparse grad. You should use"
" theano.sparse_grad(a_tensor[an_int_vector]). ")
if sparse or self.sparse_grad:
if x.type.ndim != 2: if x.type.ndim != 2:
raise TypeError( raise TypeError(
"AdvancedSubtensor1: you can't take the sparse grad" "AdvancedSubtensor1: you can't take the sparse grad"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论