提交 2c9aec23 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

various sparse grads

上级 a5fa15fe
......@@ -18,6 +18,7 @@ from theano.gof.python25 import all
from theano.gradient import DisconnectedType
from theano.sparse.utils import hash_from_sparse
import theano.tests.unittest_tools as utt
from theano.gradient import grad_not_implemented
sparse_formats = ['csc', 'csr']
......@@ -2149,7 +2150,9 @@ class MulSV(gof.op.Op):
assert y.type.ndim == 1
if x.type.dtype != y.type.dtype:
raise NotImplementedError()
raise NotImplementedError(
"MulSV not implemented for differing dtypes."
"Got %s and %s." % (str(x.type.dtype),str(y.type.dtype)))
return gof.Apply(self,
[x, y],
[SparseType(dtype=x.type.dtype,
......@@ -2163,6 +2166,15 @@ class MulSV(gof.op.Op):
def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_dense_variable(y)
assert _is_sparse_variable(gz)
# mul_s_v is not implemented if the types vary
if gz.dtype == 'float64' and y.dtype == 'float32':
y = y.astype('float64')
if gz.dtype == 'float32' and y.dtype == 'float64':
gz = gz.astype('float64')
return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True)
def infer_shape(self, node, ins_shapes):
......@@ -2197,6 +2209,11 @@ def mul(x, y):
assert x_is_sparse_variable or y_is_sparse_variable
if x_is_sparse_variable and y_is_sparse_variable:
# mul_s_s is not implemented if the types differ
if y.dtype == 'float64' and x.dtype == 'float32':
x = x.astype('float64')
return mul_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable:
......@@ -3286,7 +3303,7 @@ class SamplingDot(gof.op.Op):
rval = [
dot(p * gz, y),
dot((p * gz).T, x),
None
grad_not_implemented(self, 2, p)
]
return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论