提交 57e3a562 authored 作者: Tim Cooijmans's avatar Tim Cooijmans

map_variables: simplify tests with assertRaises

上级 354b602c
import sys
import itertools
import unittest
import numpy
import theano
from theano import tensor
......@@ -21,7 +22,7 @@ def test_equal_compuations():
# map_variables #
#################
class TestMapVariables(object):
class TestMapVariables(unittest.TestCase):
@staticmethod
def replacer(graph):
return getattr(graph.tag, "replacement", graph)
......@@ -104,15 +105,8 @@ class TestMapVariables(object):
s, _ = theano.scan(step, sequences=x,
outputs_info=[numpy.array(0.)])
try:
s2, = map_variables(self.replacer, [s])
except NotImplementedError as e:
e = sys.exc_info()[1]
assert("introduces shared variable" in str(e))
return
# test failed
return 0
self.assertRaises(NotImplementedError,
map_variables, self.replacer, [s])
def test_scan_with_shared_update2(self):
x = tensor.vector('x')
......@@ -135,15 +129,8 @@ class TestMapVariables(object):
s, _ = theano.scan(step, sequences=x,
outputs_info=[numpy.array(0.)])
try:
s2, = map_variables(self.replacer, [s])
except NotImplementedError as e:
e = sys.exc_info()[1]
assert("introduces shared variable" in str(e))
return
# test failed
return 0
self.assertRaises(NotImplementedError,
map_variables, self.replacer, [s])
def test_leaf_inside_scan(self):
x = tensor.vector('x')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论