提交 854e32c6 authored 作者: Ian Goodfellow

removed heavily numpy-dependent test

上级 8b465d41
...@@ -545,6 +545,10 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -545,6 +545,10 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
output_grads = [ access_grad_cache(var) for var in node.outputs ] output_grads = [ access_grad_cache(var) for var in node.outputs ]
input_grads = node.op.grad(inputs, output_grads) input_grads = node.op.grad(inputs, output_grads)
if input_grads is None:
raise TypeError("%s.grad returned NoneType, "
"expected iterable." % str(node.op))
if len(input_grads) != len(inputs): if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\ raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op)) "terms.") % str(node.op))
...@@ -675,6 +679,10 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'): ...@@ -675,6 +679,10 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
input_grads = node.op.grad(node.inputs, input_grads = node.op.grad(node.inputs,
[access_grad_cache(var) for var in node.outputs]) [access_grad_cache(var) for var in node.outputs])
if input_grads is None:
raise TypeError("%s.grad returned NoneType, "
"expected iterable." % str(node.op))
if len(input_grads) != len(inputs): if len(input_grads) != len(inputs):
raise ValueError(("%s returned the wrong number of gradient"+\ raise ValueError(("%s returned the wrong number of gradient"+\
"terms.") % str(node.op)) "terms.") % str(node.op))
......
...@@ -33,8 +33,7 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -33,8 +33,7 @@ class test_grad_sources_inputs(unittest.TestCase):
a = retNone().make_node() a = retNone().make_node()
try: try:
_grad_sources_inputs([(a.out, one)], None) _grad_sources_inputs([(a.out, one)], None)
except ValueError, e: except TypeError, e:
self.assertTrue(e[0] is gradient._msg_retType)
return return
self.fail() self.fail()
...@@ -137,38 +136,6 @@ class test_grad_sources_inputs(unittest.TestCase): ...@@ -137,38 +136,6 @@ class test_grad_sources_inputs(unittest.TestCase):
g = grad_sources_inputs([(a1.outputs[0], one)], None, warn_type=False) g = grad_sources_inputs([(a1.outputs[0], one)], None, warn_type=False)
self.assertTrue(g[i] is one) self.assertTrue(g[i] is one)
def test_inputs(self):
    """Test that passing inputs shortens the traversal"""
    # Dummy two-output Op used to build a small graph. It records the
    # enclosing test case (`tst`) so that grad() can fail the test if it
    # is called on a node whose gradient should never be requested.
    class O(gof.op.Op):
        def __init__(self, tst, grad_ok):
            # tst: the unittest.TestCase running this test
            # grad_ok: False means "grad() must not be reached for this op"
            self.tst = tst
            self.grad_ok = grad_ok
        def make_node(self, *inputs):
            # Every application of this op produces two matrix outputs.
            outputs = [theano.tensor.matrix(),theano.tensor.matrix()]
            return gof.Apply(self, inputs, outputs)
        def grad(self, inputs, grads):
            g0, g1 = grads
            if not self.grad_ok:
                # Traversal reached an op it should have stopped before.
                self.tst.fail()
            else:
                # NOTE(review): truthiness test on g1 — presumably g1 is
                # None (or falsy) when no gradient flowed to the second
                # output; verify against the traversal implementation.
                if g1:
                    return [g0, g0+g1]
                else:
                    return [g0, g0]
    i = theano.tensor.matrix()
    j = theano.tensor.matrix()
    k = theano.tensor.matrix()
    # Chain two applications: a2 consumes a1's second output, so a
    # full backward traversal from a2 would normally continue into a1.
    a1 = O(self, True).make_node(i,j)
    a2 = O(self, True).make_node(k,a1.outputs[1])
    # Declare a1.outputs as the graph "inputs": traversal must stop
    # there and never call a1's grad (nor touch a1.inputs).
    # Sources: weight `one` on a2.outputs[0]; 4 on a1.outputs[1]; and
    # 3 twice on a1.outputs[0] (duplicates should accumulate).
    g = _grad_sources_inputs([(a2.outputs[0], one), (a1.outputs[1],4),
            (a1.outputs[0], 3), (a1.outputs[0], 3)], a1.outputs)
    # a2's grad maps output grad 1 to [1, 1+4] per the rule above.
    self.assertTrue(g[a2.inputs[0]] == 1)
    self.assertTrue(g[a2.inputs[1]] == 5)
    # a1.outputs[0]: 3 + 3 from the duplicated source entries.
    self.assertTrue(g[a1.outputs[0]] == 6)
    # a1.outputs[1]: source 4 plus 1 flowing back through a2.
    self.assertTrue(g[a1.outputs[1]] == 5)
    # Traversal stopped at a1.outputs, so a1's own inputs got no grads.
    self.assertTrue(a1.inputs[0] not in g)
    self.assertTrue(a1.inputs[1] not in g)
def test_unimplemented_grad_func(): def test_unimplemented_grad_func():
#tests that function compilation catches unimplemented grads in the graph #tests that function compilation catches unimplemented grads in the graph
a = theano.tensor.vector() a = theano.tensor.vector()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论