提交 867b6489 authored 作者: Vivek Kulkarni's avatar Vivek Kulkarni

All cosmeticomments fixed

上级 e7618407
...@@ -49,6 +49,9 @@ continuous_dtypes = map(str, scal.continuous_types) ...@@ -49,6 +49,9 @@ continuous_dtypes = map(str, scal.continuous_types)
discrete_dtypes = map(str, scal.discrete_types) discrete_dtypes = map(str, scal.discrete_types)
all_dtypes = map(str, scal.all_types) all_dtypes = map(str, scal.all_types)
# Do a lazy import of the sparse module
sparse_module_ref = None
class ShapeError(Exception): class ShapeError(Exception):
"""Raised when the shape cannot be computed.""" """Raised when the shape cannot be computed."""
...@@ -619,7 +622,7 @@ class TensorType(Type): ...@@ -619,7 +622,7 @@ class TensorType(Type):
Inf entries. (Used in `DebugMode`) Inf entries. (Used in `DebugMode`)
""" """
def __init__(self, dtype, broadcastable, name=None, sparsegrad=False): def __init__(self, dtype, broadcastable, name=None, sparse_grad=False):
"""Initialize self.dtype and self.broadcastable. """Initialize self.dtype and self.broadcastable.
:Parameters: :Parameters:
...@@ -644,7 +647,7 @@ class TensorType(Type): ...@@ -644,7 +647,7 @@ class TensorType(Type):
self.dtype_specs() # error checking is done there self.dtype_specs() # error checking is done there
self.name = name self.name = name
self.numpy_dtype = numpy.dtype(self.dtype) self.numpy_dtype = numpy.dtype(self.dtype)
self.sparsegrad = sparsegrad self.sparse_grad = sparse_grad
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
"""Convert `data` to something which can be associated to a """Convert `data` to something which can be associated to a
...@@ -6521,11 +6524,14 @@ class AdvancedSubtensor1(Op): ...@@ -6521,11 +6524,14 @@ class AdvancedSubtensor1(Op):
return rval return rval
def grad(self, inputs, grads): def grad(self, inputs, grads):
global sparse_module_ref
gz, = grads gz, = grads
assert len(inputs) == 2 assert len(inputs) == 2
if inputs[0].type.sparsegrad: if inputs[0].type.sparse_grad:
rval1 = [theano.sparse.ConstructSparseFromList()((inputs[0]), gz, inputs[1])] if sparse_module_ref is None:
import theano.sparse as sparse_module_ref
rval1 = [sparse_module_ref.ConstructSparseFromList()((inputs[0]), gz, inputs[1])]
else: else:
rval1 = [advanced_inc_subtensor1(zeros_like(inputs[0]), gz, inputs[1])] rval1 = [advanced_inc_subtensor1(zeros_like(inputs[0]), gz, inputs[1])]
return rval1 + [DisconnectedType()()] * (len(inputs) - 1) return rval1 + [DisconnectedType()()] * (len(inputs) - 1)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论