提交 13de103e authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

fix bugs from the gradient of Cast

上级 dcc7944a
......@@ -34,8 +34,7 @@ class Cast(gof.op.Op):
def make_node(self, x):
x = as_sparse_variable(x)
return gof.Apply(
self,
[x],
self, [x],
[SparseType(dtype=self.out_type, format=x.format).make_variable()])
def perform(self, node, (x, ), (out, )):
......@@ -45,7 +44,7 @@ class Cast(gof.op.Op):
def grad(self, inputs, outputs_gradients):
if inputs[0].dtype in T.continuous_dtypes:
gz = outputs_gradients[0]
return [Cast(self.out_type)(gz)]
return [Cast(inputs[0].dtype)(gz)]
else:
return [None]
......@@ -56,7 +55,7 @@ class Cast(gof.op.Op):
return self.__class__.__name__
def astype(x, t):
def cast(x, t):
"""Cast sparse variable `x` to the desired dtype `t`.
This wrap the method astype from scipy.
......
......@@ -66,12 +66,12 @@ class TestCast(utt.InferShapeTester):
def test_cast(self):
cast_csc = dict([
(x, [theano.function([x], S2.astype(x, t))
(x, [theano.function([x], S2.Cast(t)(x))
for t in self.compatible_types])
for x in self.x_csc])
cast_csr = dict([
(x, [theano.function([x], S2.astype(x, t))
(x, [theano.function([x], S2.Cast(t)(x))
for t in self.compatible_types])
for x in self.x_csr])
......@@ -90,7 +90,7 @@ class TestCast(utt.InferShapeTester):
for t in self.compatible_types:
a = sp.csc_matrix(self.properties, dtype=x.dtype)
self._compile_and_check([x],
[S2.astype(x, t)],
[S2.Cast(t)(x)],
[a],
self.op_class)
......@@ -98,25 +98,18 @@ class TestCast(utt.InferShapeTester):
for t in self.compatible_types:
a = sp.csr_matrix(self.properties, dtype=x.dtype)
self._compile_and_check([x],
[S2.astype(x, t)],
[S2.Cast(t)(x)],
[a],
self.op_class)
def test_grad(self):
x_csc = [S.csc_matrix(dtype=t) for t in T.float_dtypes]
x_csr = [S.csr_matrix(dtype=t) for t in T.float_dtypes]
# There is a problem with the grad
# TODO Find the problem
# for x in x_csc:
# for t in T.float_dtypes:
# a = sp.csc_matrix(self.properties, dtype=x.dtype)
# verify_grad_sparse(S2.Cast(t), [a])
# for x in x_csr:
# for t in T.float_dtypes:
# a = sp.csr_matrix(self.properties, dtype=x.dtype)
# verify_grad_sparse(S2.Cast(t), [a])
for dtype in T.float_dtypes:
a = sp.csc_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast('float64'), [a])
for dtype in T.float_dtypes:
a = sp.csr_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast('float64'), [a])
class test_structured_add_s_v(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论