提交 57e3a562 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: simplify tests with assertRaises

上级 354b602c
import sys import sys
import itertools import itertools
import unittest
import numpy import numpy
import theano import theano
from theano import tensor from theano import tensor
...@@ -21,7 +22,7 @@ def test_equal_compuations(): ...@@ -21,7 +22,7 @@ def test_equal_compuations():
# map_variables # # map_variables #
################# #################
class TestMapVariables(object): class TestMapVariables(unittest.TestCase):
@staticmethod @staticmethod
def replacer(graph): def replacer(graph):
return getattr(graph.tag, "replacement", graph) return getattr(graph.tag, "replacement", graph)
...@@ -104,15 +105,8 @@ class TestMapVariables(object): ...@@ -104,15 +105,8 @@ class TestMapVariables(object):
s, _ = theano.scan(step, sequences=x, s, _ = theano.scan(step, sequences=x,
outputs_info=[numpy.array(0.)]) outputs_info=[numpy.array(0.)])
try: self.assertRaises(NotImplementedError,
s2, = map_variables(self.replacer, [s]) map_variables, self.replacer, [s])
except NotImplementedError as e:
e = sys.exc_info()[1]
assert("introduces shared variable" in str(e))
return
# test failed
return 0
def test_scan_with_shared_update2(self): def test_scan_with_shared_update2(self):
x = tensor.vector('x') x = tensor.vector('x')
...@@ -135,15 +129,8 @@ class TestMapVariables(object): ...@@ -135,15 +129,8 @@ class TestMapVariables(object):
s, _ = theano.scan(step, sequences=x, s, _ = theano.scan(step, sequences=x,
outputs_info=[numpy.array(0.)]) outputs_info=[numpy.array(0.)])
try: self.assertRaises(NotImplementedError,
s2, = map_variables(self.replacer, [s]) map_variables, self.replacer, [s])
except NotImplementedError as e:
e = sys.exc_info()[1]
assert("introduces shared variable" in str(e))
return
# test failed
return 0
def test_leaf_inside_scan(self): def test_leaf_inside_scan(self):
x = tensor.vector('x') x = tensor.vector('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论