提交 3de75d70 authored 作者: Frederic's avatar Frederic

In verify_grad, cast the projection to float32.

This help gpuarray tests to run on device that don't support float64.
上级 e7cb39cb
...@@ -1492,7 +1492,7 @@ class numeric_grad(object): ...@@ -1492,7 +1492,7 @@ class numeric_grad(object):
def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
out_type=None, abs_tol=None, out_type=None, abs_tol=None,
rel_tol=None, mode=None, cast_to_output_type=False): rel_tol=None, mode=None, cast_to_output_type=True):
"""Test a gradient by Finite Difference Method. Raise error on failure. """Test a gradient by Finite Difference Method. Raise error on failure.
Example: Example:
...@@ -1525,6 +1525,9 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1525,6 +1525,9 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
comparison comparison
:param rel_tol: relative tolerance used as threshold for gradient :param rel_tol: relative tolerance used as threshold for gradient
comparison comparison
:param cast_to_output_type: if the output is float32 and
cast_to_output_type is True, cast the random projection to
float32. Otherwise it is float64.
:note: WARNING to unit-test writers: if `op` is a function that builds :note: WARNING to unit-test writers: if `op` is a function that builds
a graph, try to make it a SMALL graph. Often verify grad is run a graph, try to make it a SMALL graph. Often verify grad is run
...@@ -1602,13 +1605,13 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1602,13 +1605,13 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
# random_projection should not have elements too small, # random_projection should not have elements too small,
# otherwise too much precision is lost in numerical gradient # otherwise too much precision is lost in numerical gradient
def random_projection(dtype): def random_projection():
plain = rng.rand(*o_fn_out.shape) + 0.5 plain = rng.rand(*o_fn_out.shape) + 0.5
if cast_to_output_type: if cast_to_output_type and o_output.dtype == "float32":
return numpy.array(plain, o_output.dtype) return numpy.array(plain, o_output.dtype)
return plain.astype(dtype) return plain
dtype = "float32" if all([p.dtype == 'float32' for p in pt]) else "float64"
t_r = shared(random_projection(dtype)) t_r = shared(random_projection())
t_r.name = 'random_projection' t_r.name = 'random_projection'
# random projection of o onto t_r # random projection of o onto t_r
...@@ -1643,7 +1646,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1643,7 +1646,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
# get new random projection for next test # get new random projection for next test
if test_num < n_tests - 1: if test_num < n_tests - 1:
t_r.set_value(random_projection(t_r.dtype), borrow=True) t_r.set_value(random_projection(), borrow=True)
except Exception, e: except Exception, e:
e.args += ("\nThe error happened with the following inputs:", pt, e.args += ("\nThe error happened with the following inputs:", pt,
"\nThe value of eps is:", eps, "\nThe value of eps is:", eps,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论