提交 2c9aec23 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

various sparse grads

上级 a5fa15fe
...@@ -18,6 +18,7 @@ from theano.gof.python25 import all ...@@ -18,6 +18,7 @@ from theano.gof.python25 import all
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.sparse.utils import hash_from_sparse from theano.sparse.utils import hash_from_sparse
import theano.tests.unittest_tools as utt import theano.tests.unittest_tools as utt
from theano.gradient import grad_not_implemented
sparse_formats = ['csc', 'csr'] sparse_formats = ['csc', 'csr']
...@@ -2149,7 +2150,9 @@ class MulSV(gof.op.Op): ...@@ -2149,7 +2150,9 @@ class MulSV(gof.op.Op):
assert y.type.ndim == 1 assert y.type.ndim == 1
if x.type.dtype != y.type.dtype: if x.type.dtype != y.type.dtype:
raise NotImplementedError() raise NotImplementedError(
"MulSV not implemented for differing dtypes."
"Got %s and %s." % (str(x.type.dtype),str(y.type.dtype)))
return gof.Apply(self, return gof.Apply(self,
[x, y], [x, y],
[SparseType(dtype=x.type.dtype, [SparseType(dtype=x.type.dtype,
...@@ -2163,6 +2166,15 @@ class MulSV(gof.op.Op): ...@@ -2163,6 +2166,15 @@ class MulSV(gof.op.Op):
def grad(self, (x, y), (gz,)): def grad(self, (x, y), (gz,)):
assert _is_sparse_variable(x) and _is_dense_variable(y) assert _is_sparse_variable(x) and _is_dense_variable(y)
assert _is_sparse_variable(gz) assert _is_sparse_variable(gz)
# mul_s_v is not implemented if the types vary
if gz.dtype == 'float64' and y.dtype == 'float32':
y = y.astype('float64')
if gz.dtype == 'float32' and y.dtype == 'float64':
gz = gz.astype('float64')
return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True) return mul_s_v(gz, y), sp_sum(x * gz, axis=0, sparse_grad=True)
def infer_shape(self, node, ins_shapes): def infer_shape(self, node, ins_shapes):
...@@ -2197,6 +2209,11 @@ def mul(x, y): ...@@ -2197,6 +2209,11 @@ def mul(x, y):
assert x_is_sparse_variable or y_is_sparse_variable assert x_is_sparse_variable or y_is_sparse_variable
if x_is_sparse_variable and y_is_sparse_variable: if x_is_sparse_variable and y_is_sparse_variable:
# mul_s_s is not implemented if the types differ
if y.dtype == 'float64' and x.dtype == 'float32':
x = x.astype('float64')
return mul_s_s(x, y) return mul_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable: elif x_is_sparse_variable and not y_is_sparse_variable:
...@@ -3286,7 +3303,7 @@ class SamplingDot(gof.op.Op): ...@@ -3286,7 +3303,7 @@ class SamplingDot(gof.op.Op):
rval = [ rval = [
dot(p * gz, y), dot(p * gz, y),
dot((p * gz).T, x), dot((p * gz).T, x),
None grad_not_implemented(self, 2, p)
] ]
return rval return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论