提交 13de103e authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

fix bugs from the gradient of Cast

上级 dcc7944a
...@@ -34,8 +34,7 @@ class Cast(gof.op.Op): ...@@ -34,8 +34,7 @@ class Cast(gof.op.Op):
def make_node(self, x): def make_node(self, x):
x = as_sparse_variable(x) x = as_sparse_variable(x)
return gof.Apply( return gof.Apply(
self, self, [x],
[x],
[SparseType(dtype=self.out_type, format=x.format).make_variable()]) [SparseType(dtype=self.out_type, format=x.format).make_variable()])
def perform(self, node, (x, ), (out, )): def perform(self, node, (x, ), (out, )):
...@@ -45,7 +44,7 @@ class Cast(gof.op.Op): ...@@ -45,7 +44,7 @@ class Cast(gof.op.Op):
def grad(self, inputs, outputs_gradients): def grad(self, inputs, outputs_gradients):
if inputs[0].dtype in T.continuous_dtypes: if inputs[0].dtype in T.continuous_dtypes:
gz = outputs_gradients[0] gz = outputs_gradients[0]
return [Cast(self.out_type)(gz)] return [Cast(inputs[0].dtype)(gz)]
else: else:
return [None] return [None]
...@@ -56,7 +55,7 @@ class Cast(gof.op.Op): ...@@ -56,7 +55,7 @@ class Cast(gof.op.Op):
return self.__class__.__name__ return self.__class__.__name__
def astype(x, t): def cast(x, t):
"""Cast sparse variable `x` to the desired dtype `t`. """Cast sparse variable `x` to the desired dtype `t`.
This wrap the method astype from scipy. This wrap the method astype from scipy.
......
...@@ -66,12 +66,12 @@ class TestCast(utt.InferShapeTester): ...@@ -66,12 +66,12 @@ class TestCast(utt.InferShapeTester):
def test_cast(self): def test_cast(self):
cast_csc = dict([ cast_csc = dict([
(x, [theano.function([x], S2.astype(x, t)) (x, [theano.function([x], S2.Cast(t)(x))
for t in self.compatible_types]) for t in self.compatible_types])
for x in self.x_csc]) for x in self.x_csc])
cast_csr = dict([ cast_csr = dict([
(x, [theano.function([x], S2.astype(x, t)) (x, [theano.function([x], S2.Cast(t)(x))
for t in self.compatible_types]) for t in self.compatible_types])
for x in self.x_csr]) for x in self.x_csr])
...@@ -90,7 +90,7 @@ class TestCast(utt.InferShapeTester): ...@@ -90,7 +90,7 @@ class TestCast(utt.InferShapeTester):
for t in self.compatible_types: for t in self.compatible_types:
a = sp.csc_matrix(self.properties, dtype=x.dtype) a = sp.csc_matrix(self.properties, dtype=x.dtype)
self._compile_and_check([x], self._compile_and_check([x],
[S2.astype(x, t)], [S2.Cast(t)(x)],
[a], [a],
self.op_class) self.op_class)
...@@ -98,25 +98,18 @@ class TestCast(utt.InferShapeTester): ...@@ -98,25 +98,18 @@ class TestCast(utt.InferShapeTester):
for t in self.compatible_types: for t in self.compatible_types:
a = sp.csr_matrix(self.properties, dtype=x.dtype) a = sp.csr_matrix(self.properties, dtype=x.dtype)
self._compile_and_check([x], self._compile_and_check([x],
[S2.astype(x, t)], [S2.Cast(t)(x)],
[a], [a],
self.op_class) self.op_class)
def test_grad(self): def test_grad(self):
x_csc = [S.csc_matrix(dtype=t) for t in T.float_dtypes] for dtype in T.float_dtypes:
x_csr = [S.csr_matrix(dtype=t) for t in T.float_dtypes] a = sp.csc_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast('float64'), [a])
# There is a problem with the grad
# TODO Find the problem for dtype in T.float_dtypes:
# for x in x_csc: a = sp.csr_matrix(self.properties, dtype=dtype)
# for t in T.float_dtypes: verify_grad_sparse(S2.Cast('float64'), [a])
# a = sp.csc_matrix(self.properties, dtype=x.dtype)
# verify_grad_sparse(S2.Cast(t), [a])
# for x in x_csr:
# for t in T.float_dtypes:
# a = sp.csr_matrix(self.properties, dtype=x.dtype)
# verify_grad_sparse(S2.Cast(t), [a])
class test_structured_add_s_v(unittest.TestCase): class test_structured_add_s_v(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论