提交 3de75d70 authored 作者: Frederic's avatar Frederic

In verify_grad, cast the projection to float32.

This help gpuarray tests to run on device that don't support float64.
上级 e7cb39cb
......@@ -1492,7 +1492,7 @@ class numeric_grad(object):
def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
out_type=None, abs_tol=None,
rel_tol=None, mode=None, cast_to_output_type=False):
rel_tol=None, mode=None, cast_to_output_type=True):
"""Test a gradient by Finite Difference Method. Raise error on failure.
Example:
......@@ -1525,6 +1525,9 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
comparison
:param rel_tol: relative tolerance used as threshold for gradient
comparison
:param cast_to_output_type: if the output is float32 and
cast_to_output_type is True, cast the random projection to
float32. Otherwise it is float64.
:note: WARNING to unit-test writers: if `op` is a function that builds
a graph, try to make it a SMALL graph. Often verify grad is run
......@@ -1602,13 +1605,13 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
# random_projection should not have elements too small,
# otherwise too much precision is lost in numerical gradient
def random_projection(dtype):
def random_projection():
plain = rng.rand(*o_fn_out.shape) + 0.5
if cast_to_output_type:
if cast_to_output_type and o_output.dtype == "float32":
return numpy.array(plain, o_output.dtype)
return plain.astype(dtype)
dtype = "float32" if all([p.dtype == 'float32' for p in pt]) else "float64"
t_r = shared(random_projection(dtype))
return plain
t_r = shared(random_projection())
t_r.name = 'random_projection'
# random projection of o onto t_r
......@@ -1643,7 +1646,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
# get new random projection for next test
if test_num < n_tests - 1:
t_r.set_value(random_projection(t_r.dtype), borrow=True)
t_r.set_value(random_projection(), borrow=True)
except Exception, e:
e.args += ("\nThe error happened with the following inputs:", pt,
"\nThe value of eps is:", eps,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论