提交 27f80b39 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Refactor tests.gof.test_lazy and enable an incorrectly named test

上级 50ab0430
from copy import deepcopy import pytest
import numpy as np import numpy as np
import theano import theano
from theano.gof.op import PureOp import theano.tensor as tt
from theano.gof import Apply, generic
from copy import deepcopy
from theano import function, Mode from theano import function, Mode
from theano.gof import Apply, generic
from theano.gof.op import PureOp
from theano.ifelse import ifelse from theano.ifelse import ifelse
import theano.tensor as tt
class IfElseIfElseIf(PureOp): class IfElseIfElseIf(PureOp):
...@@ -88,16 +91,17 @@ class IfElseIfElseIf(PureOp): ...@@ -88,16 +91,17 @@ class IfElseIfElseIf(PureOp):
return thunk return thunk
class NotImplementedOp(PureOp): class NotImplementedOpException(Exception):
class E(Exception): pass
pass
class NotImplementedOp(PureOp):
def make_node(self, x): def make_node(self, x):
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
def make_thunk(self, node, storage_map, compute_map, no_recycling, impl): def make_thunk(self, node, storage_map, compute_map, no_recycling, impl):
def thunk(): def thunk():
raise self.E() raise NotImplementedOpException()
thunk.lazy = False thunk.lazy = False
return thunk return thunk
...@@ -109,13 +113,17 @@ def test_ifelse(): ...@@ -109,13 +113,17 @@ def test_ifelse():
c = generic() c = generic()
notimpl = NotImplementedOp() notimpl = NotImplementedOp()
lazys = [True] lazys = [True]
# We need lazy to end up being True for this test. # We need lazy to end up being True for this test.
if theano.config.vm.lazy in [True, None]: if theano.config.vm.lazy in [True, None]:
lazys = [True, None] lazys = [True, None]
cloops = [True, False] cloops = [True, False]
if theano.config.cxx == "": if theano.config.cxx == "":
cloops = [False] cloops = [False]
for cloop in cloops: for cloop in cloops:
for lazy in lazys: for lazy in lazys:
linker = theano.gof.vm.VM_Linker(use_cloop=cloop, lazy=lazy) linker = theano.gof.vm.VM_Linker(use_cloop=cloop, lazy=lazy)
...@@ -125,21 +133,13 @@ def test_ifelse(): ...@@ -125,21 +133,13 @@ def test_ifelse():
mode=Mode(linker=linker, optimizer="fast_run"), mode=Mode(linker=linker, optimizer="fast_run"),
) )
try: with pytest.raises(NotImplementedOpException):
# print "case 1"
f(1, "a", "b") f(1, "a", "b")
assert False
except NotImplementedOp.E:
pass
# print "... passed"
# print "case 2"
# print f(0, 'a', 'b')
assert f(0, "a", "b") == "b" assert f(0, "a", "b") == "b"
# print "... passed"
def more_complex_test(): def test_nested():
notimpl = NotImplementedOp() notimpl = NotImplementedOp()
ifelseifelseif = IfElseIfElseIf() ifelseifelseif = IfElseIfElseIf()
...@@ -156,18 +156,11 @@ def more_complex_test(): ...@@ -156,18 +156,11 @@ def more_complex_test():
t4 = ifelseifelseif(tt.eq(x1, x2), x1, tt.eq(x1, 5), x2, c2, t3, t3 + 0.5) t4 = ifelseifelseif(tt.eq(x1, x2), x1, tt.eq(x1, 5), x2, c2, t3, t3 + 0.5)
t4.name = "t4" t4.name = "t4"
f = function([c1, c2, x1, x2], t4, mode=Mode(linker="vm", optimizer="fast_run")) linker = theano.gof.vm.VM_Linker(lazy=False)
if theano.config.vm.lazy is False: f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
try: with pytest.raises(NotImplementedOpException):
f(1, 0, np.array(10, dtype=x1.dtype), 0) f(1, 0, np.array(10, dtype=x1.dtype), 0)
assert False
except NotImplementedOp.E: linker = theano.gof.vm.VM_Linker(lazy=True)
pass f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
else: assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5
print(f(1, 0, np.array(10, dtype=x1.dtype), 0))
assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5
print("... passed")
if __name__ == "__main__":
more_complex_test()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论