提交 416e7060 authored 作者: Rami Al-Rfou''s avatar Rami Al-Rfou'

Merge pull request #4 from viveksck/TheanoBugFix

Lazy import of sparse module and rename sparsegrad according to python coding style.
......@@ -49,6 +49,9 @@ continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types)
all_dtypes = map(str, scal.all_types)
# Do a lazy import of the sparse module
sparse_module_ref = None
class ShapeError(Exception):
"""Raised when the shape cannot be computed."""
......@@ -619,7 +622,7 @@ class TensorType(Type):
Inf entries. (Used in `DebugMode`)
"""
def __init__(self, dtype, broadcastable, name=None, sparsegrad=False):
def __init__(self, dtype, broadcastable, name=None, sparse_grad=False):
"""Initialize self.dtype and self.broadcastable.
:Parameters:
......@@ -644,7 +647,7 @@ class TensorType(Type):
self.dtype_specs() # error checking is done there
self.name = name
self.numpy_dtype = numpy.dtype(self.dtype)
self.sparsegrad = sparsegrad
self.sparse_grad = sparse_grad
def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a
......@@ -6521,11 +6524,14 @@ class AdvancedSubtensor1(Op):
return rval
def grad(self, inputs, grads):
global sparse_module_ref
gz, = grads
assert len(inputs) == 2
if inputs[0].type.sparsegrad:
rval1 = [theano.sparse.ConstructSparseFromList()((inputs[0]), gz, inputs[1])]
if inputs[0].type.sparse_grad:
if sparse_module_ref is None:
import theano.sparse as sparse_module_ref
rval1 = [sparse_module_ref.ConstructSparseFromList()((inputs[0]), gz, inputs[1])]
else:
rval1 = [advanced_inc_subtensor1(zeros_like(inputs[0]), gz, inputs[1])]
return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论