提交 84a1a252 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

Fix problem with verify_grad for hstack, vstack and cast

上级 a3f52604
......@@ -105,17 +105,21 @@ class TestCast(utt.InferShapeTester):
self.op_class)
def test_grad(self):
# TODO Find the problem with the grad of downcast
# a = sp.csc_matrix(self.properties, dtype='float64')
# verify_grad_sparse(S2.Cast('float32'), [a], cast_to_output_type=True)
for dtype in tensor.float_dtypes:
a = sp.csc_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast('float64'), [a])
for t in tensor.float_dtypes:
eps = None
if t == 'float32':
eps = 7e-4
a = sp.csc_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast(t), [a], eps=eps)
for dtype in tensor.float_dtypes:
a = sp.csr_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast('float64'), [a])
for t in tensor.float_dtypes:
eps = None
if t == 'float32':
eps = 7e-4
a = sp.csr_matrix(self.properties, dtype=dtype)
verify_grad_sparse(S2.Cast(t), [a], eps=eps)
class HVStackTester(utt.InferShapeTester):
......@@ -161,11 +165,16 @@ class HVStackTester(utt.InferShapeTester):
def test_grad(self):
for format in sparse.sparse_formats:
for out_f in sparse.sparse_formats:
for dtype in ['float64']: # sparse.float_dtypes:
for dtype in sparse.float_dtypes:
eps = None
if dtype == 'float32':
eps = 7e-4
verify_grad_sparse(
self.op_class(format=out_f, dtype=dtype),
self.mat[format],
structured=False)
structured=False,
eps=eps)
def _hv_switch(op, expected_function):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论