提交 5236fff2 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

move more tests to deterministic interface

上级 c502818a
...@@ -60,11 +60,11 @@ class T_updates(unittest.TestCase): ...@@ -60,11 +60,11 @@ class T_updates(unittest.TestCase):
data = numpy.float32([1, 2, 3, 4]) data = numpy.float32([1, 2, 3, 4])
x = f32sc(data) x = f32sc(data)
y = x ** 2 y = x ** 2
f = theano.function([], y, updates={x: x + 1}) f = theano.function([], y, updates=[(x, x + 1)])
f() f()
# Test that we can update with a CudaVariable # Test that we can update with a CudaVariable
f = theano.function([], y, updates={x: cuda.gpu_from_host(x + 1)}) f = theano.function([], y, updates=[(x, cuda.gpu_from_host(x + 1))])
f() f()
def test_2(self): def test_2(self):
...@@ -74,7 +74,7 @@ class T_updates(unittest.TestCase): ...@@ -74,7 +74,7 @@ class T_updates(unittest.TestCase):
value=numpy.zeros((10, 10), 'float32')) value=numpy.zeros((10, 10), 'float32'))
x = tensor.fmatrix('x') x = tensor.fmatrix('x')
output_updates = {output_var: x ** 2} output_updates = [(output_var, x ** 2)]
output_givens = {x: data} output_givens = {x: data}
output_func = theano.function(inputs=[], outputs=[], output_func = theano.function(inputs=[], outputs=[],
updates=output_updates, givens=output_givens) updates=output_updates, givens=output_givens)
...@@ -89,8 +89,8 @@ class T_updates(unittest.TestCase): ...@@ -89,8 +89,8 @@ class T_updates(unittest.TestCase):
# the update_var has type matrix, and the update expression # the update_var has type matrix, and the update expression
# is a broadcasted scalar, and that should not be allowed. # is a broadcasted scalar, and that should not be allowed.
self.assertRaises(TypeError, theano.function, inputs=[], outputs=[], self.assertRaises(TypeError, theano.function, inputs=[], outputs=[],
updates={output_var: updates=[(output_var,
output_var.sum()}) output_var.sum())])
def test_err_broadcast(self): def test_err_broadcast(self):
# Test that we raise a good error message when we don't # Test that we raise a good error message when we don't
...@@ -101,8 +101,8 @@ class T_updates(unittest.TestCase): ...@@ -101,8 +101,8 @@ class T_updates(unittest.TestCase):
# the update_var has type matrix, and the update expression # the update_var has type matrix, and the update expression
# is a broadcasted scalar, and that should not be allowed. # is a broadcasted scalar, and that should not be allowed.
self.assertRaises(TypeError, theano.function, inputs=[], outputs=[], self.assertRaises(TypeError, theano.function, inputs=[], outputs=[],
updates={output_var: updates=[(output_var,
output_var.sum().dimshuffle('x', 'x')}) output_var.sum().dimshuffle('x', 'x'))])
def test_broadcast(self): def test_broadcast(self):
# Test that we can rebroadcast # Test that we can rebroadcast
......
...@@ -1219,7 +1219,7 @@ class UsmmTests(unittest.TestCase): ...@@ -1219,7 +1219,7 @@ class UsmmTests(unittest.TestCase):
mode = theano.compile.mode.get_default_mode().excluding('fusion') mode = theano.compile.mode.get_default_mode().excluding('fusion')
if inplace: if inplace:
updates = {z: z - a * theano.sparse.dot(x, y)} updates = [(z, z - a * theano.sparse.dot(x, y))]
f_a = theano.function([a, x, y], [], f_a = theano.function([a, x, y], [],
updates=updates, updates=updates,
mode=mode) mode=mode)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论