提交 a455eb67 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

A few fixes to prepare for change in grad output

I caught places where the warning about using grad w.r.t. a one-element list was raised when running the test suite. This commit removes these warnings (ensuring that changing the default value of keep_wrt_type to True will not affect the tests). Note that I changed the behavior of 'test_dot_w_self'. I asked its author (James) before doing so: it turns out it was not behaving as intended before.
上级 17cf3dc5
...@@ -632,7 +632,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No ...@@ -632,7 +632,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
g_cost = cast(g_cost, o_output.dtype) g_cost = cast(g_cost, o_output.dtype)
symbolic_grad = grad(cost, tensor_pt, g_cost, symbolic_grad = grad(cost, tensor_pt, g_cost,
disconnected_inputs='ignore') disconnected_inputs='ignore',
keep_wrt_type=True)
#if o_output.dtype in ['float32','float64']: #if o_output.dtype in ['float32','float64']:
# assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp)) # assert all([x.dtype == o_output.dtype for x in symbolic_grad]),("Expected grad of type %s, got %s "%( symbolic_grad.dtype, o_output.dtyp))
...@@ -644,8 +645,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No ...@@ -644,8 +645,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
analytic_grad = grad_fn(*[p.copy() for p in pt]) analytic_grad = grad_fn(*[p.copy() for p in pt])
if not isinstance(analytic_grad, (list, tuple)): # Since `tensor_pt` is a list, `analytic_grad` should be one too.
analytic_grad = [analytic_grad] assert isinstance(analytic_grad, list)
max_arg, max_err_pos, max_abs_err, max_rel_err =\ max_arg, max_err_pos, max_abs_err, max_rel_err =\
num_grad.max_err(analytic_grad, abs_tol, rel_tol) num_grad.max_err(analytic_grad, abs_tol, rel_tol)
......
...@@ -808,13 +808,13 @@ def test_dot_w_self(): ...@@ -808,13 +808,13 @@ def test_dot_w_self():
# This can trigger problems in the optimization because what would normally be a gemm must # This can trigger problems in the optimization because what would normally be a gemm must
# not be because the output is aliased to one of the inputs. # not be because the output is aliased to one of the inputs.
A = shared(value = numpy.ones((2,2))) A = shared(value=numpy.ones((2,2)))
B = T.matrix() B = T.matrix()
p = T.dot(A,A)*B p = T.dot(A,A)*B
grad = T.grad(T.mean(p),[A]) grad = T.grad(T.mean(p), A)
f = theano.function([B], p, updates = { A : A - grad[0]} ) f = theano.function([B], p, updates={A : A - grad})
# tests correctness in debugmode # tests correctness in debugmode
f(numpy.asarray([[0,1], [2,3]], dtype=config.floatX)) f(numpy.asarray([[0,1], [2,3]], dtype=config.floatX))
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论