提交 4ca430a8 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

removed uses of the warn_type flag

上级 378430eb
......@@ -16,11 +16,8 @@ from theano.gof.null_type import NullType
one = theano.tensor.as_tensor_variable(1.)
def _grad_sources_inputs(*args):
# warn_type was introduced after this code, it complains throughout for nothing.
return grad_sources_inputs(warn_type=False, *args)
class test_grad_sources_inputs(unittest.TestCase):
class testgrad_sources_inputs(unittest.TestCase):
def test_retNone1(self):
"""Test that it is not ok to return None from op.grad()"""
......@@ -35,7 +32,7 @@ class test_grad_sources_inputs(unittest.TestCase):
pass
a = retNone().make_node()
try:
_grad_sources_inputs([(a.out, one)], None)
grad_sources_inputs([(a.out, one)], None)
except TypeError, e:
return
self.fail()
......@@ -52,10 +49,10 @@ class test_grad_sources_inputs(unittest.TestCase):
i = theano.tensor.vector()
j = theano.tensor.vector()
a1 = retOne().make_node(i)
g = _grad_sources_inputs([(a1.out, one)], None)
g = grad_sources_inputs([(a1.out, one)], None)
a2 = retOne().make_node(i,j)
try:
g = _grad_sources_inputs([(a2.out, one)], None)
g = grad_sources_inputs([(a2.out, one)], None)
except ValueError, e:
return
self.fail()
......@@ -71,7 +68,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, inp, grads):
return gval,
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], one)], None)
g = grad_sources_inputs([(a1.outputs[0], one)], None)
self.assertTrue(g[a1.inputs[0]] is gval)
def test_1in_Nout(self):
......@@ -87,7 +84,7 @@ class test_grad_sources_inputs(unittest.TestCase):
gz1, gz2 = grads
return gval,
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], one)], None)
g = grad_sources_inputs([(a1.outputs[0], one)], None)
self.assertTrue(g[a1.inputs[0]] is gval)
def test_Nin_1out(self):
......@@ -104,7 +101,7 @@ class test_grad_sources_inputs(unittest.TestCase):
gz, = grads
return (gval0, gval1)
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], one)], None)
g = grad_sources_inputs([(a1.outputs[0], one)], None)
self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1)
......@@ -120,7 +117,7 @@ class test_grad_sources_inputs(unittest.TestCase):
def grad(self, inp, grads):
return gval0, gval1
a1 = O().make_node()
g = _grad_sources_inputs([(a1.outputs[0], one)], None)
g = grad_sources_inputs([(a1.outputs[0], one)], None)
self.assertTrue(g[a1.inputs[0]] is gval0)
self.assertTrue(g[a1.inputs[1]] is gval1)
......@@ -136,7 +133,7 @@ class test_grad_sources_inputs(unittest.TestCase):
return [one]
i = theano.tensor.matrix()
a1 = O(self).make_node(i)
g = grad_sources_inputs([(a1.outputs[0], one)], None, warn_type=False)
g = grad_sources_inputs([(a1.outputs[0], one)], None)
self.assertTrue(g[i] is one)
def test_unimplemented_grad_func():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论